当前位置: 首页 > news >正文

0基础学习PyTorch——时尚分类(Fashion MNIST)训练和推理

大纲

  • 环境准备
  • 安装依赖
  • 下载训练集
  • 训练
    • 定义模型
    • 训练
      • 加载训练集
      • 定义损失函数和优化器
      • 训练模型
      • 保存模型
      • 完整文件
  • 推理
    • 加载模型
    • 加载并预处理本地文件
    • 推理
    • 完整文件
  • 代码地址
  • 参考资料

时尚分类是PyTorch官方文档中推荐的案例。本文将拆解这个案例,进行部署以及测试。

环境准备

基础环境可以参考《0基础学习PyTorch——最小Demo》来进行部署。

安装依赖

torchvision 是 PyTorch 的一个官方库,专门用于计算机视觉任务。它提供了常用的数据集、模型架构和图像处理工具,简化了计算机视觉项目的开发过程。后续我们的数据都来源于该库。

source env.sh install torchvision

在这里插入图片描述

下载训练集

将下列内容保存为download.py。

# download.py
import torchvision# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', download=True)# Report split sizes
print('Training set has {} instances'.format(len(training_set)))
print('Validation set has {} instances'.format(len(validation_set)))

然后运行这个文件

python download.py

在这里插入图片描述
此时目录结构如下
在这里插入图片描述

训练

定义模型

将下面内容保存为garmentclassifier.py。该文件会被训练和推理两个环节使用。

# garmentclassifier.py
import torch.nn as nn
import torch.nn.functional as F# 定义一个用于服装分类的卷积神经网络
class GarmentClassifier(nn.Module):def __init__(self):super(GarmentClassifier, self).__init__()# 定义第一个卷积层,输入通道数为1,输出通道数为6,卷积核大小为5x5self.conv1 = nn.Conv2d(1, 6, 5)# 定义最大池化层,池化窗口大小为2x2self.pool = nn.MaxPool2d(2, 2)# 定义第二个卷积层,输入通道数为6,输出通道数为16,卷积核大小为5x5self.conv2 = nn.Conv2d(6, 16, 5)# 定义第一个全连接层,输入大小为16*4*4,输出大小为120self.fc1 = nn.Linear(16 * 4 * 4, 120)# 定义第二个全连接层,输入大小为120,输出大小为84self.fc2 = nn.Linear(120, 84)# 定义第三个全连接层,输入大小为84,输出大小为10(对应10个类别)self.fc3 = nn.Linear(84, 10)def forward(self, x):# 通过第一个卷积层和ReLU激活函数,然后通过最大池化层x = self.pool(F.relu(self.conv1(x)))# 通过第二个卷积层和ReLU激活函数,然后通过最大池化层x = self.pool(F.relu(self.conv2(x)))# 展平张量,从多维张量变为二维张量x = x.view(-1, 16 * 4 * 4)# 通过第一个全连接层和ReLU激活函数x = F.relu(self.fc1(x))# 通过第二个全连接层和ReLU激活函数x = F.relu(self.fc2(x))# 通过第三个全连接层(输出层)x = self.fc3(x)return x

训练

加载训练集

这次我们直接从本地加载训练集,但是需要做归一化处理。

from datetime import datetime
import torch
import torchvision
import torchvision.transforms as transforms
from garmentclassifier import GarmentClassifier# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]) # 对图像的每个通道进行标准化,使得每个通道的像素值具有零均值和单位标准差# 加载FashionMNIST训练数据集,并应用定义的图像转换操作
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform)# 创建数据加载器,用于批量加载训练数据,batch_size为4,数据顺序随机打乱
trainloader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)

定义损失函数和优化器

# 实例化模型
model = GarmentClassifier()
# 定义损失函数为交叉熵损失
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器为随机梯度下降(SGD),学习率为0.001,动量为0.9
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

训练模型

# 训练模型,训练2个epoch
for epoch in range(2):running_loss = 0.0  # 初始化累计损失# 枚举数据加载器中的数据,i是批次索引,data是当前批次的数据for i, data in enumerate(trainloader, 0):inputs, labels = data  # 获取输入数据和对应的标签optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播,计算模型输出loss = loss_fn(outputs, labels)  # 计算损失loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新模型参数running_loss += loss.item()  # 累加损失# 每2000个批次打印一次平均损失if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')running_loss = 0.0  # 重置累计损失

保存模型

# 获取当前时间戳,格式为 'YYYYMMDD_HHMMSS'
timestamp = datetime.now().strftime('%Y%m%d%H%M%S.pth')# 定义模型保存路径,包含时间戳
model_path = 'model_{}'.format(timestamp)      # 保存模型的状态字典到指定路径
torch.save(model.state_dict(), model_path)

完整文件

from datetime import datetime
import torch
import torchvision
import torchvision.transforms as transforms
from garmentclassifier import GarmentClassifier# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]) # 对图像的每个通道进行标准化,使得每个通道的像素值具有零均值和单位标准差# 加载FashionMNIST训练数据集,并应用定义的图像转换操作
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform)# 创建数据加载器,用于批量加载训练数据,batch_size为4,数据顺序随机打乱
trainloader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)# 实例化模型
model = GarmentClassifier()
# 定义损失函数为交叉熵损失
loss_fn = torch.nn.CrossEntropyLoss()
# 定义优化器为随机梯度下降(SGD),学习率为0.001,动量为0.9
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 训练模型,训练2个epoch
for epoch in range(2):running_loss = 0.0  # 初始化累计损失# 枚举数据加载器中的数据,i是批次索引,data是当前批次的数据for i, data in enumerate(trainloader, 0):inputs, labels = data  # 获取输入数据和对应的标签optimizer.zero_grad()  # 清空梯度outputs = model(inputs)  # 前向传播,计算模型输出loss = loss_fn(outputs, labels)  # 计算损失loss.backward()  # 反向传播,计算梯度optimizer.step()  # 更新模型参数running_loss += loss.item()  # 累加损失# 每2000个批次打印一次平均损失if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000}')running_loss = 0.0  # 重置累计损失# 获取当前时间戳,格式为 'YYYYMMDD_HHMMSS'
timestamp = datetime.now().strftime('%Y%m%d%H%M%S.pth')# 定义模型保存路径,包含时间戳
model_path = 'model_{}'.format(timestamp)      # 保存模型的状态字典到指定路径
torch.save(model.state_dict(), model_path)

执行该文件,我们会得到一个后缀为pth的模型文件。

推理

加载模型

我们加载上一步创建的模型。

import os
import glob
import torch
import torchvision.transforms as transforms
from PIL import Image
from datetime import datetime
from garmentclassifier import GarmentClassifierdef get_latest_model_path(directory, pattern="model_*.pth"):# 获取目录下所有符合模式的文件model_files = glob.glob(os.path.join(directory, pattern))if not model_files:raise FileNotFoundError("No model files found in the directory.")# 找到最新的模型文件latest_model_file = max(model_files, key=os.path.getmtime)return latest_model_file# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.Resize((28, 28)),  # 调整图像大小为28x28transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载训练好的模型
model = GarmentClassifier()
model_path = get_latest_model_path('./')  # 获取最新的模型文件
model.load_state_dict(torch.load(model_path, weights_only=False)) # 加载模型参数
model.eval()  # 设置模型为评估模式

加载并预处理本地文件

在这里插入图片描述

# 从本地加载图像
image_path = 'shoe.jpg'  # 替换为实际的图像路径
image = Image.open(image_path).convert('L')  # 将图像转换为灰度图# 预处理图像
image = transform(image)
image = image.unsqueeze(0)  # 增加一个批次维度

我们使用transform进行归一化处理。

推理

# 推理(预测)
with torch.no_grad():  # 在推理过程中不需要计算梯度outputs = model(image)  # 前向传播,计算模型输出_, predicted = torch.max(outputs, 1)  # 获取预测结果# 定义类别名称
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')# 打印预测结果
print(f'Predicted label: {classes[predicted.item()]}')

由于推理出来的是索引号,为了方便解读,我们将类型映射打印出来。

完整文件

import os
import glob
import torch
import torchvision.transforms as transforms
from PIL import Image
from datetime import datetime
from garmentclassifier import GarmentClassifierdef get_latest_model_path(directory, pattern="model_*.pth"):# 获取目录下所有符合模式的文件model_files = glob.glob(os.path.join(directory, pattern))if not model_files:raise FileNotFoundError("No model files found in the directory.")# 找到最新的模型文件latest_model_file = max(model_files, key=os.path.getmtime)return latest_model_file# 定义图像转换操作:将图像转换为张量,并进行归一化处理
transform = transforms.Compose([transforms.Resize((28, 28)),  # 调整图像大小为28x28transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载训练好的模型
model = GarmentClassifier()
model_path = get_latest_model_path('./')  # 获取最新的模型文件
model.load_state_dict(torch.load(model_path, weights_only=False)) # 加载模型参数
model.eval()  # 设置模型为评估模式# 从本地加载图像
image_path = 'shoe.jpg'  # 替换为实际的图像路径
image = Image.open(image_path).convert('L')  # 将图像转换为灰度图# 预处理图像
image = transform(image)
image = image.unsqueeze(0)  # 增加一个批次维度# 推理(预测)
with torch.no_grad():  # 在推理过程中不需要计算梯度outputs = model(image)  # 前向传播,计算模型输出_, predicted = torch.max(outputs, 1)  # 获取预测结果# 定义类别名称
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot')# 打印预测结果
print(f'Predicted label: {classes[predicted.item()]}')

执行这个文件,我们看到推理结果是:Sandal(凉鞋)。
在这里插入图片描述

代码地址

https://github.com/f304646673/deeplearning/tree/main/FashionMNIST

参考资料

  • https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

相关文章:

  • 阿里云函数计算 x NVIDIA 加速企业 AI 应用落地
  • 10.Lab Nine —— file system-上
  • 丹摩智算(damodel)部署stable diffusion实验
  • 三子棋小游戏
  • 【React】组件通信
  • Android 已经过时的方法用什么新方法替代?
  • 使用Python解决数据分析中的相关性分析
  • macOS 15 Blank OVF - macOS Sequoia 虚拟化解决方案
  • 分享个锂电池升压芯片,3.7V升5V大电流输出的芯片。AD2403 PWM升压芯片
  • 如何创建一个包含多个列的表?
  • 828华为云征文|华为云Flexus云服务器X实例——部署EduSoho网校系统、二次开发对接华为云视频点播实现CDN加速播放
  • fastadmin本地安装插件提示”请从官网渠道下载插件压缩包(code:2)(code:1)“
  • 数据结构(Day18)
  • vue-pdf 实现pdf预览、高亮、分页、定位功能
  • 问题记录:end value has mixed support, consider using flex-end instead
  • Android系统模拟器绘制实现概述
  • Angular Elements 及其运作原理
  • ESLint简单操作
  • leetcode388. Longest Absolute File Path
  • node和express搭建代理服务器(源码)
  • win10下安装mysql5.7
  • 从@property说起(二)当我们写下@property (nonatomic, weak) id obj时,我们究竟写了什么...
  • 分类模型——Logistics Regression
  • 网页视频流m3u8/ts视频下载
  • 一起来学SpringBoot | 第十篇:使用Spring Cache集成Redis
  • ​业务双活的数据切换思路设计(下)
  • (11)MSP430F5529 定时器B
  • (8)Linux使用C语言读取proc/stat等cpu使用数据
  • (java)关于Thread的挂起和恢复
  • (k8s)kubernetes 部署Promehteus学习之路
  • (接口封装)
  • (三)elasticsearch 源码之启动流程分析
  • (实测可用)(3)Git的使用——RT Thread Stdio添加的软件包,github与gitee冲突造成无法上传文件到gitee
  • (转)C语言家族扩展收藏 (转)C语言家族扩展
  • (轉貼) 寄發紅帖基本原則(教育部禮儀司頒布) (雜項)
  • .Net Framework 4.x 程序到底运行在哪个 CLR 版本之上
  • .net MVC中使用angularJs刷新页面数据列表
  • .NET 读取 JSON格式的数据
  • .net 受管制代码
  • .Net 中Partitioner static与dynamic的性能对比
  • .NET/C# 使用反射调用含 ref 或 out 参数的方法
  • @data注解_一枚 架构师 也不会用的Lombok注解,相见恨晚
  • @property @synthesize @dynamic 及相关属性作用探究
  • [240621] Anthropic 发布了 Claude 3.5 Sonnet AI 助手 | Socket.IO 拒绝服务漏洞
  • [AHK V2]鼠标悬停展开窗口,鼠标离开折叠窗口
  • [C#]OpenCvSharp 实现Bitmap和Mat的格式相互转换
  • [C++]spdlog学习
  • [C++打怪升级]--学习总目录
  • [Django学习]查询过滤器(lookup types)
  • [ESP32] 编码旋钮驱动
  • [HJ73 计算日期到天数转换]
  • [IE9] IE9 RC版下载链接
  • [IE编程] IE中使网页元素进入编辑模式
  • [imx9]DDR test Tool for imx9
  • [JS]Math.random()随机数的二三事