当前位置: 首页 > news >正文

(动手学习深度学习)第13章 计算机视觉---微调

文章目录

    • 微调
      • 总结
    • 微调代码实现

微调

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

总结

  • 微调通过使用在大数据上的恶道的预训练好的模型来初始化模型权重来完成提升精度。
  • 预训练模型质量很重要
  • 微调通常速度更快、精确度更高

微调代码实现

  1. 导入相关库
%matplotlib inline
import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
import matplotlib as plt
  1. 获取数据集
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir,'test'))
print(train_imgs)
print(train_imgs[0])
train_imgs[0][0]

在这里插入图片描述
查看数据集中图像的形状

hotdogs = [train_imgs[i][0] for i in range(8)]
not_hotdogs= [train_imgs[-i-1][0] for i in range(8)]
d2l.show_images(hotdogs + not_hotdogs, 2 ,8, scale=1.4)

在这里插入图片描述

  1. 数据增强
# 图像增广
normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224,0.225]
)
train_augs = torchvision.transforms.Compose(  # 训练集数据增强[torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize]
)
test_augs = torchvision.transforms.Compose(  # 验证集不做数据增强[torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize]
)
  1. 定义和初始化模型
# 下载resnet18,
# 老:pretrain=True: 也下载预训练的模型参数
# 新:weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
pretrained_net = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
print(pretrained_net.fc)

在这里插入图片描述

  1. 微调模型
  • (1)直接修改网络层(如最后全连接层:512—>1000,改成512—>2)
  • (2)在增加一层分类层(如:512—>1000, 改成512—>1000, 1000—>2)

本次选择(1):将resnet18最后全连接层的输出,改成自己训练集的类别,并初始化最后全连接层的权重参数

finetune_net = pretrained_net
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

在这里插入图片描述

print(finetune_net)

在这里插入图片描述

  1. 训练模型
  • 特征提取层(预训练层):使用较小的学习率
  • 输出全连接层(微调层):使用较大的学习率
def train_fine_tuning(net, learning_rate, batch_size=128, num_epochs=10, param_group=True):train_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir,'train'), transform=train_augs),batch_size=batch_size,shuffle=True)test_iter = torch.utils.data.DataLoader(torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'), transform=test_augs),batch_size=batch_size)device = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction='none')if param_group:params_1x = [param for name, param in net.named_parameters()if name not in ['fc.weight', 'fc.bias']]trainer = torch.optim.SGD([{'params': params_1x}, {'params': net.fc.parameters(), 'lr': learning_rate * 10}],lr=learning_rate, weight_decay=0.001)else:trainer = torch.optim.SGD(net.parameters(),lr=learning_rate,weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss,trainer, num_epochs, device)

训练模型

import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(finetune_net, 5e-5, 128, 10)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

直接训练:整个模型都使用相同的学习率,重新训练

scracth_net = torchvision.models.resnet18()
scracth_net.fc = nn.Linear(scracth_net.fc.in_features, 2)import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train_fine_tuning(scracth_net, 5e-4, param_group=False)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f} s')

在这里插入图片描述

相关文章:

  • C++中sort()函数的greater<int>()参数
  • 如何为视频添加旁白,有哪些操作技巧?
  • unexpected end of stream on
  • Python武器库开发-flask篇之error404(二十七)
  • CMS与FullGC
  • 动态规划43(Leetcode91解码方法)
  • JS原生-弹框+阿里巴巴矢量图
  • 华为摄像头通过stm32叠加字符串
  • WPF中Dispatcher对象的用途是什么
  • 分发糖果(贪心算法)
  • VivadoAndTcl: namespace
  • 【Essential C++学习笔记】第四章 基于对象的编程风格
  • SIMULIA-Simpack 2022x新功能介绍
  • 11.16~11.19绘制图表,导入EXCEL中数据,进行拟合
  • 纯JS,RSA,AES,公钥,私钥生成及加解密
  • 「译」Node.js Streams 基础
  • ES6, React, Redux, Webpack写的一个爬 GitHub 的网页
  • es的写入过程
  • Javascript编码规范
  • JDK9: 集成 Jshell 和 Maven 项目.
  • MySQL用户中的%到底包不包括localhost?
  • React-生命周期杂记
  • 第13期 DApp 榜单 :来,吃我这波安利
  • 前嗅ForeSpider教程:创建模板
  • 小程序滚动组件,左边导航栏与右边内容联动效果实现
  • 异步
  • 正则表达式
  • 翻译 | The Principles of OOD 面向对象设计原则
  • ​虚拟化系列介绍(十)
  • #define与typedef区别
  • #LLM入门|Prompt#1.7_文本拓展_Expanding
  • #LLM入门|Prompt#2.3_对查询任务进行分类|意图分析_Classification
  • (k8s中)docker netty OOM问题记录
  • (八十八)VFL语言初步 - 实现布局
  • (读书笔记)Javascript高级程序设计---ECMAScript基础
  • (二开)Flink 修改源码拓展 SQL 语法
  • (附源码)php新闻发布平台 毕业设计 141646
  • (附源码)ssm捐赠救助系统 毕业设计 060945
  • (规划)24届春招和25届暑假实习路线准备规划
  • (含react-draggable库以及相关BUG如何解决)固定在左上方某盒子内(如按钮)添加可拖动功能,使用react hook语法实现
  • (经验分享)作为一名普通本科计算机专业学生,我大学四年到底走了多少弯路
  • (欧拉)openEuler系统添加网卡文件配置流程、(欧拉)openEuler系统手动配置ipv6地址流程、(欧拉)openEuler系统网络管理说明
  • (全注解开发)学习Spring-MVC的第三天
  • (四)模仿学习-完成后台管理页面查询
  • (转)socket Aio demo
  • .form文件_SSM框架文件上传篇
  • .NET 动态调用WebService + WSE + UsernameToken
  • .Net(C#)常用转换byte转uint32、byte转float等
  • .NET/C# 获取一个正在运行的进程的命令行参数
  • .Net下使用 Geb.Video.FFMPEG 操作视频文件
  • .NET正则基础之——正则委托
  • /dev/sda2 is mounted; will not make a filesystem here!
  • ::
  • [Electron]ipcMain.on和ipcMain.handle的区别
  • [javaSE] 数据结构(二叉查找树-插入节点)