当前位置: 首页 > news >正文

EfficientNet-v2-s图像分类训练(简洁版)

使用torchvision集成的efficientnet-v2-s模型,调用torchvision库中的Oxford IIIT Pet数据集,对模型进行训练。
若有修改要求,可以修改以下部分:

train_dataset = OxfordIIITPet(root='./data', split='trainval', download=True, transform=transform_train)
test_dataset = OxfordIIITPet(root='./data', split='test', download=True, transform=transform_test)
#常见数据集可以直接加载,若是自己的数据集就自己写个dataset/dataloader
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 37)
#37为数据集类别数,修改为自己对应的
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.1, verbose=True)
#学习率处可以自己调整,可玩性较高

训练截图:
在这里插入图片描述

其实十轮左右就稳定在90以上了,跑了三十轮,记得修改保存路径,我这里是用kaggle跑的。
代码如下:

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.models import efficientnet_v2_s
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import OxfordIIITPet
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm# 数据预处理 + 数据增强
transform_train = transforms.Compose([transforms.Resize((256, 256)),  # 增大图片预处理尺寸transforms.RandomCrop((224, 224)),  # 随机裁剪到模型输入尺寸transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # 颜色抖动transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 加载数据集
train_dataset = OxfordIIITPet(root='./data', split='trainval', download=True, transform=transform_train)
test_dataset = OxfordIIITPet(root='./data', split='test', download=True, transform=transform_test)train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 模型定义
model = efficientnet_v2_s(pretrained=True)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 37)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.001)
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=3, factor=0.1, verbose=True)# 训练模型
def train_model(num_epochs):model.train()best_accuracy = 0for epoch in range(num_epochs):model.train()running_loss = 0.0for inputs, labels in tqdm(train_loader, desc=f'Training Epoch {epoch + 1}'):inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 每个 epoch 后测试accuracy = test_model()scheduler.step(accuracy)# 如果当前模型表现更好,保存模型if accuracy > best_accuracy:best_accuracy = accuracytorch.save(model.state_dict(), '/kaggle/working/best_oxford_pets_efficientnetv2.pth')print(f'New best model saved with accuracy: {best_accuracy:.2f}%')def test_model():model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in tqdm(test_loader, desc='Testing'):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 调试输出if total < 50:  # 只打印前50个样本的信息print(f'Predicted: {predicted[:10]}, Labels: {labels[:10]}')accuracy = 100 * correct / totalprint(f'Testing Accuracy: {accuracy:.2f}%')return accuracytrain_model(30)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • DataX介绍
  • Python模块中的全局变量
  • Mecanim Animation System
  • Golang | Leetcode Golang题解之第310题最小高度树
  • 音视频入门基础:WAV专题(5)——FFmpeg源码中解码WAV Header的实现
  • Linux Socket TCP处理粘包问题
  • 实现基于 Python 和 xterm.js 的 Web 交互终端demo
  • 掌控情绪,驾驭人生,在人生的漫长旅程中,情绪如同多变的天气,时而风和日丽,时而狂风骤雨
  • pypinyin,一个有趣的 Python 库!
  • 关于qt中如何布局
  • c++ - 模拟实现set、map
  • vscode启动不了的问题解决
  • 5 mysql 查询语句
  • Java中等题-多数元素2(力扣)【摩尔投票升级版】
  • 黑暗之魂和艾尔登法环有什么联系吗 黑暗之魂和艾尔登法环哪一个好玩 苹果电脑怎么玩Windows游戏 apple电脑可以玩游戏吗
  • [分享]iOS开发 - 实现UITableView Plain SectionView和table不停留一起滑动
  • 0x05 Python数据分析,Anaconda八斩刀
  • android百种动画侧滑库、步骤视图、TextView效果、社交、搜房、K线图等源码
  • Angular6错误 Service: No provider for Renderer2
  • IE报vuex requires a Promise polyfill in this browser问题解决
  • javascript面向对象之创建对象
  • JavaScript设计模式系列一:工厂模式
  • JavaScript设计模式之工厂模式
  • Just for fun——迅速写完快速排序
  • Laravel 中的一个后期静态绑定
  • leetcode386. Lexicographical Numbers
  • opencv python Meanshift 和 Camshift
  • Python学习笔记 字符串拼接
  • vue自定义指令实现v-tap插件
  • 包装类对象
  • 短视频宝贝=慢?阿里巴巴工程师这样秒开短视频
  • 开源地图数据可视化库——mapnik
  • 入手阿里云新服务器的部署NODE
  • 文本多行溢出显示...之最后一行不到行尾的解决
  • 新版博客前端前瞻
  • ​Distil-Whisper:比Whisper快6倍,体积小50%的语音识别模型
  • ​ssh免密码登录设置及问题总结
  • # 深度解析 Socket 与 WebSocket:原理、区别与应用
  • (6)【Python/机器学习/深度学习】Machine-Learning模型与算法应用—使用Adaboost建模及工作环境下的数据分析整理
  • (el-Transfer)操作(不使用 ts):Element-plus 中 Select 组件动态设置 options 值需求的解决过程
  • (补充):java各种进制、原码、反码、补码和文本、图像、音频在计算机中的存储方式
  • (附源码)计算机毕业设计ssm基于B_S的汽车售后服务管理系统
  • (七)Flink Watermark
  • (限时免费)震惊!流落人间的haproxy宝典被找到了!一切玄妙尽在此处!
  • (一) 初入MySQL 【认识和部署】
  • (状压dp)uva 10817 Headmaster's Headache
  • *_zh_CN.properties 国际化资源文件 struts 防乱码等
  • *++p:p先自+,然后*p,最终为3 ++*p:先*p,即arr[0]=1,然后再++,最终为2 *p++:值为arr[0],即1,该语句执行完毕后,p指向arr[1]
  • ./和../以及/和~之间的区别
  • .bat批处理(九):替换带有等号=的字符串的子串
  • .mat 文件的加载与创建 矩阵变图像? ∈ Matlab 使用笔记
  • .NET COER+CONSUL微服务项目在CENTOS环境下的部署实践
  • .Net Core 笔试1
  • .net core 实现redis分片_基于 Redis 的分布式任务调度框架 earth-frost
  • .Net IE10 _doPostBack 未定义