遥感影像分类任务的复现
目录
一、概述
二、环境配置
三、运行
1.dataset
2.configs
代码下载地址:GitHub - cxyth/rs-segment.pytorch: 用于遥感影像分类任务的语义分割模板工程
部分解码器下载地址:
qubvel/segmentation_models.pytorch: Segmentation models with pretrained backbones. PyTorch. (github.com)
一、概述
该代码中有如下几个文件夹:
- configs中是一个配置文件,用来运行模型
- dataset中是一个myDataset.py的文件,主要用来处理输入的数据,可以将自己的数据集放到此处
- models中包含了该项目可用的encoders如resnet、densenet等,如果需要对decoders代码进行改进,可以将需要改进的解码器文件放到该目录下,解码器文件获取地址decoders
- scripts中包含了一系列数据处理文件
- utils中包含了损失函数,混淆矩阵等内容的定义
二、环境配置
- torch>1.6
- torchsummay
- segmentation_models_pytorch
pip install git+https://github.com/qubvel/segmentation_models.pytorch
(或者直接下载工程然后python setup.py install)- gdal (gdal无法通过pip下载,详细请见:)
- albumentations
- opencv
- matplotlib
- scipy
- scikit-learn
- scikit-image
- tqdm
- pandas
三、运行
python main.py -c test.yml -m train -g 0,1
其中main.py是我们需要运行的main文件 ,yml.py是调用的配置文件,train/test表示要训练还是测试,-g 0,1 表示要用的GPU,不写该项默认为0。训练和测试的结果会放在新建的runs文件中。
可以看到这个模型使用起来是非常简单的,其中比较重要的就是数据怎么放和如何写配置文件来运行你想调用的模型。
1.dataset
我们可以按照一下目录在dataset文件夹中设置自己数据,Test的文件夹以同样的方式放入dataset文件夹中 。
'images'目录存放图像,'labels'目录存放标签,并保证图像和对应的标签同名。
2.configs
如下是一个configs的示例,具体内容详见代码
mode: "train" # "train", "infer"
run_name: ""
run_dir: "./runs/"
comment: "for debug" # 其他备注信息
dataset_params: {
name: "DeepGlobe-road", #项目的名字
train_dirs: ["E:\\New_gj\\dataset\\T_multi\\split\\train"], #训练集文件夹
val_dirs: ["E:\\New_gj\\dataset\\T_multi\\split\\val"], #验证集文件夹
image_ext: ".tif", #文件格式
cls_info: { #分类类别的类别名及value值,0为背景
nodata: 0,
suger: 1,
corn: 2,
paddy: 3,
banana: 4,
orange: 5,
},
cls_color: [ #设置mask的颜色
[0,0,0],
[255, 0, 0],
[255, 255, 0],
[255, 0, 255],
[0, 255, 0],
[0, 0, 255],
],
resample: false #是否进行重采样
}
network_params: {
type: "smp", # "custom", ... #smp为使用库内模型,custom为本地模型
arch: "Unet", #decoder_name
encoder: "efficientnet-b1", #encoder_name
in_height: 512, #输入尺寸
in_width: 512,
in_channel: 3, #输入通道
out_channel: 6, #加背景的分类数 #输出通道,即类别数
pretrained: "imagenet"
}
train_params: {
epochs: 100, # 3, 9, 21, 45, 93...
batch_size: 8,
lr: 0.001,
smoothing: 0.1,
cutmix: false,
gamma: 0.98, # 学习率衰减系数
momentum: 0.9, # 动量
weight_decay: 0.0005, # 权重衰减
save_inter: 2, # 保存间隔(epoch)
min_inter: 10, # 保存起始点(epoch)
iter_inter: 100, # 显示迭代间隔(batch)
plot: true
}
inference_params: {
ckpt_name: "checkpoint-best.pth", # full path = os.path.join(run_dir, run_name, "ckpt", ckpt_name)
in_dir: "E:\\New_gj\\dataset\\T_multi\\predict\\images", #测试的images的文件地址,注意是images的地址
out_dir: "val_best_single", # full path = os.path.join(run_dir, out_dir, "results")
tile_size: 512,
overlap: 256,
batch_size: 1,
tta: false,
draw: true,
evaluate: false
}