基于Swin-Transformer的语义分割实战:从VOC数据集制作到模型训练

张开发
2026/5/17 18:34:29 15 分钟阅读
基于Swin-Transformer的语义分割实战:从VOC数据集制作到模型训练
1. 为什么选择Swin-Transformer做语义分割第一次接触Swin-Transformer是在2021年ICCV会议上这个模型获得了最佳论文奖。当时我就被它的设计思路吸引了——它完美结合了CNN的层次化结构和Transformer的全局建模能力。在实际项目中测试后发现相比传统CNN模型Swin-Transformer在语义分割任务上确实有显著优势。最让我印象深刻的是它的计算效率。传统ViT模型在处理高分辨率图像时计算量会呈平方级增长。而Swin-Transformer通过局部窗口注意力机制将复杂度降到了线性级别。举个例子在处理512x512的医学图像时训练速度比普通ViT快了近3倍显存占用也减少了40%左右。另一个优势是它的多尺度特征提取能力。通过分层设计模型可以像CNN一样逐步扩大感受野。这在处理VOC这类包含不同尺度物体的数据集时特别有用。实测在PASCAL VOC2012上使用相同训练策略Swin-Tiny比ResNet50的mIoU高出约5个百分点。2. 准备VOC格式数据集2.1 数据集目录结构规范VOC数据集的组织方式很有讲究。我建议按照以下结构整理VOCdevkit/ └── VOC2012/ ├── ImageSets/ │ └── Segmentation/ │ ├── train.txt │ ├── val.txt │ └── trainval.txt ├── JPEGImages/ │ └── *.jpg └── SegmentationClass/ └── *.png这里最容易出错的是标注图像的处理。很多人会犯两个错误一是使用JPG格式存储标注必须用PNG二是忽略标签值的连续性。比如做二分类时背景必须是0前景必须是1中间不能有其他数值。我有次训练时发现准确率始终上不去后来发现是标注图里混入了255这个值。2.2 标注工具推荐对于自制数据集我常用Labelme和CVAT这两个工具。Labelme适合小规模标注它的JSON格式可以方便地转为VOC格式import json import cv2 from pathlib import Path def labelme_to_voc(json_path, output_dir): with open(json_path) as f: data json.load(f) img cv2.imread(data[imagePath]) label np.zeros(img.shape[:2], dtypenp.uint8) for shape in data[shapes]: points np.array(shape[points], dtypenp.int32) cv2.fillPoly(label, [points], color1) # 假设是二分类 cv2.imwrite(str(output_dir/SegmentationClass/Path(json_path).stem.png), label)对于大规模标注CVAT的团队协作功能更实用。它支持直接导出VOC格式还能做标注质量检查。3. 环境配置与模型准备3.1 搭建Python环境我强烈建议使用conda创建独立环境。以下是经过多次测试最稳定的配置方案conda create -n swinseg python3.8 -y conda activate swinseg conda install pytorch1.8.1 torchvision0.9.1 cudatoolkit11.1 -c pytorch pip install mmcv-full1.3.17 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html注意MMCV的版本必须与PyTorch匹配。曾经有次用了MMCV 1.4.0导致训练时出现诡异的NaN损失回退到1.3.17才解决。3.2 下载和修改Swin-Transformer代码从官方仓库克隆代码后需要做几处关键修改修改configs/_base_/datasets/pascal_voc12.py中的data_root路径在mmseg/datasets/voc.py中更新CLASSES和PALETTE单卡训练时需要将SyncBN改为BN# 在configs/_base_/models/upernet_swin.py中 norm_cfgdict(typeBN, requires_gradTrue)对于二分类任务还需要修改num_classes参数。我写了个批量修改的脚本import os import mmcv config_files [configs/swin/upernet_swin_tiny_patch4_window7_512x512.py, configs/_base_/models/upernet_swin.py] for cfg_file in config_files: cfg mmcv.Config.fromfile(cfg_file) cfg.model.decode_head.num_classes 2 cfg.model.auxiliary_head.num_classes 2 cfg.dump(cfg_file)4. 模型训练技巧与调优4.1 关键训练参数设置在tools/train.py中这几个参数对结果影响最大# 学习率设置 optimizer dict( typeAdamW, lr6e-5, # 小数据集可以降到3e-5 betas(0.9, 0.999), weight_decay0.01) # 数据增强 train_pipeline [ dict(typeLoadImageFromFile), dict(typeLoadAnnotations), dict(typeRandomFlip, prob0.5), dict(typePhotoMetricDistortion), # 亮度、对比度扰动 dict(typeNormalize, mean[123.675, 116.28, 103.53], std[58.395, 57.12, 57.375]), dict(typePad, size(512, 512), pad_val0, seg_pad_val255), dict(typeDefaultFormatBundle), dict(typeCollect, keys[img, gt_semantic_seg]) ]我发现在VOC上使用更大的crop size如640x640能提升小物体的分割效果但需要调整window_size参数保持整除关系。4.2 解决常见训练问题问题1损失值震荡大解决方法减小学习率并增加warmup步数lr_config dict( policypoly, warmuplinear, warmup_iters1500, # 从500增加到1500 warmup_ratio1e-6, power1.0, min_lr0.0, by_epochFalse)问题2显存不足降低batch size使用梯度累积optimizer_config dict( typeGradientCumulativeOptimizerHook, cumulative_iters2) # 每2个iter更新一次问题3过拟合增加数据增强早停策略evaluation dict( interval1, metricmIoU, save_bestmIoU, rulegreater)5. 模型测试与部署训练完成后可以用这个脚本测试单张图片from mmseg.apis import inference_segmentor, init_segmentor import mmcv config_file configs/swin/upernet_swin_tiny_patch4_window7_512x512.py checkpoint_file work_dirs/latest.pth model init_segmentor(config_file, checkpoint_file, devicecuda:0) img test.jpg result inference_segmentor(model, img) # 可视化 palette [[0,0,0], [255,255,255]] # 根据类别修改 seg_map result[0].astype(np.uint8) seg_img Image.fromarray(seg_map).convert(P) seg_img.putpalette(np.array(palette, dtypenp.uint8).flatten()) seg_img.save(result.png)对于工业部署我推荐将模型导出为ONNX格式python tools/pytorch2onnx.py \ configs/swin/upernet_swin_tiny_patch4_window7_512x512.py \ work_dirs/latest.pth \ --output-file model.onnx \ --shape 512 512在部署时有个小技巧如果输入分辨率固定可以修改config中的img_size参数重新导出模型能提升推理速度。我在 Jetson Xavier 上测试512x512的输入能达到15FPS完全满足实时性要求。

更多文章