当前位置: 首页 > news >正文

06-使用pytorch实现手写数字识别

目录

1.思路和流程分析

2.准备训练集和测试集

2.1 torchvision.transforms的图形数据处理方法

2.1.1 torchvison.transforms.ToTensor

2.1.2 torchvision.transforms.Normalize(mean,std)

2.1.3 torchvision.transforms.Compose(transforms)

2.2 准备MNIST数据集的Dataset和DataLoader

3.构建模型

3.1 激活函数的使用

3.2 模型中数据的形状(【添加形状变化图形】)

3.3 模型的损失函数

4.模型的训练

5.模型的保存和加载

5.1 模型的保存

5.2 模型的加载

6.模型的评估

7.总的代码


1.思路和流程分析

2.准备训练集和测试集

2.1 torchvision.transforms的图形数据处理方法

2.1.1 torchvison.transforms.ToTensor

from torchvision.datasets import MNIST
mnist=MNIST(root=r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\datas',train=True,download=True,transform=None)#len(mnist)==60000
print(mnist[0])#(<PIL.Image.Image image mode=L size=28x28 at 0x1D5F6298EE0>, 5)
img=mnist[0][0]
img.show()#打开图片

from torchvision import transforms
import numpy as np
data=np.random.randint(0,255,size=12)
img=data.reshape(2,2,3)
print(img.shape)
img_tensor=transforms.ToTensor()(img)#转换成tensor
print(img_tensor)
print(img_tensor.size())

输出如下:

(2, 2, 3)
tensor([[[235,  30],
         [236,  92]],

        [[  1, 113],
         [ 53,   5]],

        [[ 21, 190],
         [ 46,  11]]], dtype=torch.int32)
torch.Size([3, 2, 2])

2.1.2 torchvision.transforms.Normalize(mean,std)

from torchvision import transforms
import numpy as np
import torchvision
data=np.random.randint(0,255,size=12)
img=data.reshape(2,2,3)
img=transforms.ToTensor()(img)#转换成tensor
print(img)
print('*'*100)
norm_img=transforms.Normalize((10,10,10),(1,1,1))(img)#进行规范化处理
print(norm_img)

2.1.3 torchvision.transforms.Compose(transforms)

transforms.Compose([
    torchvision.transforms.ToTensor(),#先转换为Tensor
    torchvision.transforms.Normalize(mean,std)#再进行正则化
])

2.2 准备MNIST数据集的Dataset和DataLoader

准备训练集:

from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoader

BATCH_SIZE=128
#1.准备数据
def get_dataloader(train=True):
    transform_fn = Compose([
        ToTensor(),
        Normalize(mean=(0.1307,), std=(0.3081,))  # mean std的形状和通道数相同
    ])
    dataset = MNIST(root=r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\datas', train=True, transform=transform_fn)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    return data_loader

3.构建模型

3.1 激活函数的使用

import torch
import torch.nn.functional as F
b=torch.Tensor([-2,-1,0,1,2])
print(F.relu(b))    #tensor([0., 0., 0., 1., 2.])

3.2 模型中数据的形状(【添加形状变化图形】)

import torch.nn as nn
import torch.nn.functional as F

#2.构建模型
class MnistNodel(nn.Module):
    def __init__(self):
        super(MnistNodel,self).__init__()
        self.fc1=nn.Linear(1*28*28,28)#第一个全连接
        self.fc2=nn.Linear(28,10)#第二个全连接 最终有10个类别

    def forward(self,input):
        """
        :param input:[batch_size,1,28,28]
        :return:输出层
        """
        #1.修改形状
        x=input.view([input.size(0),1*28*28])#或者input.view([-1,1*28*28])
        #2,进行全连接的操作
        x=self.fc1(x)
        #3.进行激活函数的处理
        x=F.relu(x)#形状无变化
        #4.输出层
        out=self.fc2(x)
        return out

3.3 模型的损失函数

#方法一
criterion=nn.CrossEntropyLoss()#交叉熵损失
loss=criterion(input,target)

#方法二
output=F.log_softmax(x,dim=-1)#1.对输出值计算softmax和取对数
loss=F.nll_loss(output,target)#2.使用torch中带权损失nll_loss

4.模型的训练

from torch.optim import Adam

model=MnistNodel()#实例化模型
optimizer=Adam(model.parameters(),lr=0.001)
def train(epoch):
    '''实现训练的过程'''
    data_loader=get_dataloader()
    for idx,(input,target) in enumerate(data_loader):
        optimizer.zero_grad()
        output=model(input)#调用模型,得到预测值
        loss=F.nll_loss(output,target)#得到损失
        loss.backward()#反向传播
        optimizer.step()#梯度的更新
        if idx%100==0:
            print(epoch,idx,loss.item(),sep='\t')

5.模型的保存和加载

5.1 模型的保存

torch.save(model.state_dict(),'path')#保存模型参数
torch.save(optimizer.state_dict(),'path')#保存优化器参数

5.2 模型的加载

model.load_state_dict(torch.load('path'))
optimizer.load_state_dict(torch.load('path'))

6.模型的评估

import numpy as np

def test():
    loss_list=[]
    acc_list=[]
    test_dataloader=get_dataloader(train=False)
    for idx,(input,target) in enumerate(test_dataloader):
        with torch.no_grad():
            output=model(input)
            cur_loss=F.nll_loss(output,target)
            loss_list.append(cur_loss)
            #计算准确率
            #output [batch_size,10] target:[batch_size]
            pred=output.max(dim=-1)[-1]
            cur_acc=pred.eq(target).float().mean()
            acc_list.append(cur_acc)
    print('平均准确率:',np.mean(acc_list),'\t平均损失:',np.mean(loss_list))

#结果如下:
#    平均准确率: 0.9503709 	平均损失: 0.17310049

7.总的代码

import torch,os
import numpy as np
from torchvision.datasets import MNIST
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

BATCH_SIZE=128
#1.准备数据
def get_dataloader(train=True):
    transform_fn = Compose([
        ToTensor(),
        Normalize(mean=(0.1307,), std=(0.3081,))  # mean std的形状和通道数相同
    ])
    dataset = MNIST(root=r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\datas', train=True, transform=transform_fn)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    return data_loader

#2.构建模型
class MnistNodel(nn.Module):
    def __init__(self):
        super(MnistNodel,self).__init__()
        self.fc1=nn.Linear(1*28*28,28)#第一个全连接
        self.fc2=nn.Linear(28,10)#第二个全连接 最终有10个类别

    def forward(self,input):
        """
        :param input:[batch_size,1,28,28]
        :return:输出层
        """
        #1.修改形状
        x=input.view([input.size(0),1*28*28])#或者input.view([-1,1*28*28])
        #2,进行全连接的操作
        x=self.fc1(x)
        #3.进行激活函数的处理
        x=F.relu(x)#形状无变化
        #4.输出层
        out=self.fc2(x)
        return F.log_softmax(out)

model=MnistNodel()#实例化模型
if os.path.exists(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_model.pkl'):
    model.load_state_dict(torch.load(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_model.pkl'))
optimizer=Adam(model.parameters(),lr=0.001)
if os.path.exists(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_optimizer.pkl'):
    optimizer.load_state_dict(torch.load(r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_optimizer.pkl'))

def train(epoch):
    '''实现训练的过程'''
    data_loader=get_dataloader()
    for idx,(input,target) in enumerate(data_loader):
        optimizer.zero_grad()
        output=model(input)#调用模型,得到预测值
        loss=F.nll_loss(output,target)#得到损失
        loss.backward()#反向传播
        optimizer.step()#梯度的更新
        if idx%100==0:
            print(epoch,idx,loss.item(),sep='\t')
        #模型的保存
        if idx%100==0:#每隔100个保存一下
            torch.save(model.state_dict(),r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_model.pkl')
            torch.save(optimizer.state_dict(),r'D:\各种编译器的代码\pythonProject12\机器学习\NLP自然语言处理\模型的保存\mnist_optimizer.pkl')

def test():
    loss_list=[]
    acc_list=[]
    test_dataloader=get_dataloader(train=False)
    for idx,(input,target) in enumerate(test_dataloader):
        with torch.no_grad():
            output=model(input)
            cur_loss=F.nll_loss(output,target)
            loss_list.append(cur_loss)
            #计算准确率
            #output [batch_size,10] target:[batch_size]
            pred=output.max(dim=-1)[-1]
            cur_acc=pred.eq(target).float().mean()
            acc_list.append(cur_acc)
    print('平均准确率:',np.mean(acc_list),'\t平均损失:',np.mean(loss_list))


if __name__ == '__main__':
    # for i in range(3):#训练三轮
    #     train(i)
    test()

相关文章:

  • 高级特效开发阶段学习总结
  • WPF 简单的ComboBox自定义样式。
  • Servlet 规范和 Servlet 容器
  • 切面的优先级、基于XML的AOP实现
  • 【Java面试宝典】常用类中的方法重写|equals方法与逻辑运算符==的区别
  • 重构的原则
  • Restyle起来!
  • 【Unity3D日常BUG】Unity3D中出现“unsafe code 不安全的代码”的错误时的解决方法
  • Node中实现一个简易的图片验证码流程
  • java-Lambda表达式
  • Robotics System Toolbox中的机器人运动(7)--RRT规划避障路径
  • 和一个海归的博士聊人生
  • 移动端布局介绍——css像素/物理像素/设备像素比
  • redis简介及八种数据类型
  • GAN Step By Step -- Step1 GAN介绍
  • 「面试题」如何实现一个圣杯布局?
  • 【Amaple教程】5. 插件
  • 【跃迁之路】【735天】程序员高效学习方法论探索系列(实验阶段492-2019.2.25)...
  • C++类的相互关联
  • Docker 1.12实践:Docker Service、Stack与分布式应用捆绑包
  • Fundebug计费标准解释:事件数是如何定义的?
  • Java Agent 学习笔记
  • js中的正则表达式入门
  • JWT究竟是什么呢?
  • learning koa2.x
  • nodejs:开发并发布一个nodejs包
  • rabbitmq延迟消息示例
  • uva 10370 Above Average
  • 初探 Vue 生命周期和钩子函数
  • 从0实现一个tiny react(三)生命周期
  • 湖南卫视:中国白领因网络偷菜成当代最寂寞的人?
  • 利用阿里云 OSS 搭建私有 Docker 仓库
  • 聊聊flink的TableFactory
  • shell使用lftp连接ftp和sftp,并可以指定私钥
  • 阿里云重庆大学大数据训练营落地分享
  • 长三角G60科创走廊智能驾驶产业联盟揭牌成立,近80家企业助力智能驾驶行业发展 ...
  • #pragma multi_compile #pragma shader_feature
  • (4)事件处理——(2)在页面加载的时候执行任务(Performing tasks on page load)...
  • (Java实习生)每日10道面试题打卡——JavaWeb篇
  • (k8s中)docker netty OOM问题记录
  • (LNMP) How To Install Linux, nginx, MySQL, PHP
  • (二)斐波那契Fabonacci函数
  • (翻译)terry crowley: 写给程序员
  • (附源码)php新闻发布平台 毕业设计 141646
  • (附源码)小程序 交通违法举报系统 毕业设计 242045
  • (转)ORM
  • (转载)PyTorch代码规范最佳实践和样式指南
  • .bat批处理(一):@echo off
  • .htaccess 强制https 单独排除某个目录
  • .NET CF命令行调试器MDbg入门(四) Attaching to Processes
  • .net core webapi 部署iis_一键部署VS插件:让.NET开发者更幸福
  • .NET Core 将实体类转换为 SQL(ORM 映射)
  • .NET Core、DNX、DNU、DNVM、MVC6学习资料
  • .Net Framework 4.x 程序到底运行在哪个 CLR 版本之上
  • .NET Reactor简单使用教程