当前位置: 首页 > news >正文

PyTorch深度学习模型训练流程:(一、分类)

自己写了个封装PyTorch深度学习训练流程的函数,实现了根据输入参数训练模型并可视化训练过程的功能,可以方便快捷地检验一个模型的效果,有助于提高选择模型架构、优化超参数等工作的效率。发出来供大家参考,如有不足之处,欢迎批评讨论。

分类是人工智能的一个非常重要的应用,这篇文章分享的函数适用于实现分类的深度学习模型,包括以下功能:

  1. 根据输入的数据集、模型、优化器、损失函数等参数训练一个分类模型;
  2. 使用visdom可视化训练过程,实时输出精确度曲线、损失曲线、混淆矩阵和ROC曲线;
  3. 支持二分类和多分类;
  4. 输入数据集支持形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型,可以方便灵活地调用torch内置数据集或自定义数据集;
  5. 支持使用GPU加速深度学习模型的训练。

废话不多说,先来看下输出效果:

二分类
多分类

 深度学习的完整流程通常包括以下几个步骤:

  1. 收集数据
  2. 数据预处理
  3. 选择模型
  4. 训练模型
  5. 评估模型
  6. 超参数调优
  7. 测试模型

本函数封装了训练模型和评估模型的步骤,包括:

  1. 若数据集为(X,y)形式则分离训练集和测试集(测试集占20%),数据标准化,封装训练集和测试集;
  2. 将训练集和测试集设置为加载器;
  3. 遍历训练集加载器,计算每一批次的输出和损失,并反向传播更新神经网络参数;
  4. 每迭代100次评估一下模型,用测试集数据计算并画出精确度曲线、损失曲线、混淆矩阵和ROC曲线。

代码如下:

from functools import partial
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, roc_curve, r2_scoreimport torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from visdom import Visdomfrom typing import Union, Optional
from sklearn.base import TransformerMixin
from torch.optim.optimizer import Optimizerdef classify(data: tuple[Union[np.ndarray, Dataset], Union[np.ndarray, Dataset]],model: nn.Module,optimizer: Optimizer,criterion: nn.Module,scaler: Optional[TransformerMixin] = None,batch_size: int = 64,epochs: int = 10,device: Optional[torch.device] = None
) -> nn.Module:"""分类任务的训练函数。:param data: 形如(X,y)的np.ndarray类型,及形如(train_data,test_data)的torch.utils.data.Dataset类型:param model: 分类模型:param optimizer: 优化器:param criterion: 损失函数:param scaler: 数据标准化器:param batch_size: 批大小:param epochs: 训练轮数:param device: 训练设备:return: 训练好的分类模型"""if isinstance(data[0], np.ndarray):X, y = data# 处理类别classes = np.unique(y)classes_str = [str(i) for i in classes]num_classes = len(classes)# 分离训练集和测试集,指定随机种子以便复现X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化if scaler is not None:X_train = scaler.fit_transform(X_train)X_test = scaler.transform(X_test)# 转换为tensorX_train = torch.from_numpy(X_train.astype(np.float32))X_test = torch.from_numpy(X_test.astype(np.float32))y_train = torch.from_numpy(y_train.astype(np.int64))y_test = torch.from_numpy(y_test.astype(np.int64))# 将X和y封装成TensorDatasettrain_dataset = TensorDataset(X_train, y_train)test_dataset = TensorDataset(X_test, y_test)elif isinstance(data[0], Dataset):train_dataset, test_dataset = dataclasses = list(train_dataset.class_to_idx.values())classes_str = train_dataset.classesnum_classes = len(classes)else:raise ValueError('Unsupported data type')train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True,num_workers=2,)model.to(device)vis = Visdom()# 训练模型for epoch in range(epochs):for step, (batch_x_train, batch_y_train) in enumerate(train_loader):batch_x_train = batch_x_train.to(device)batch_y_train = batch_y_train.to(device)# 前向传播output = model(batch_x_train)loss = criterion(output, batch_y_train)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()niter = epoch * len(train_loader) + step + 1  # 计算迭代次数if niter % 100 == 0:# 评估模型model.eval()with torch.no_grad():eval_dict = {'test_loss': [],'test_acc': [],'test_cm': [],'test_roc': [],}for batch_x_test, batch_y_test in test_loader:batch_x_test = batch_x_test.to(device)batch_y_test = batch_y_test.to(device)test_output = model(batch_x_test)predicted = torch.argmax(test_output, 1)test_predicted_tuple = (batch_y_test.numpy(), predicted.numpy())# 计算并记录损失、精确度、混淆矩阵、ROC曲线eval_dict['test_loss'].append(criterion(test_output, batch_y_test))eval_dict['test_acc'].append(accuracy_score(*test_predicted_tuple))eval_dict['test_cm'].append(confusion_matrix(*test_predicted_tuple, labels=classes))if num_classes == 2:# eval_dict['test_roc']形状为(len,(fpr,tpr),3)eval_dict['test_roc'].append(roc_curve(*test_predicted_tuple)[:2])else:# 多分类ROC曲线需要one-hot编码y_test_one_hot, predicted_one_hot = map(partial(label_binarize, classes=classes), test_predicted_tuple)fpr_list = []tpr_list = []for i in range(num_classes):fpr, tpr, _ = roc_curve(y_test_one_hot[:, i], predicted_one_hot[:, i])# 无(fpr,tpr)数据点时,插值填充(0,0)数据点if len(fpr) != 3:fpr = np.insert(fpr, 0, 0)tpr = np.insert(tpr, 0, 0)fpr_list.append(fpr)tpr_list.append(tpr)# eval_dict['test_roc']形状为(len,(fpr,tpr),num_classes,3)eval_dict['test_roc'].append((fpr_list, tpr_list))# 画出损失曲线vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.stack((loss, torch.mean(torch.tensor(eval_dict['test_loss'])))).unsqueeze(0),win='loss',update='append',opts=dict(title='Loss', legend=['train_loss', 'test_loss']),)# 画出精确度曲线train_acc = accuracy_score(batch_y_train.numpy(), torch.argmax(output, 1).numpy())vis.line(X=torch.ones((1, 2)) * (niter // 100),Y=torch.tensor((train_acc, np.mean(eval_dict['test_acc']))).unsqueeze(0),win='accuracy',update='append',opts=dict(title='Accuracy', legend=['train_acc', 'test_acc'], ytickmin=0, ytickmax=1),)# 画出混淆矩阵vis.heatmap(X=np.add.reduce(eval_dict['test_cm']),win='confusion_matrix',opts=dict(title='Confusion Matrix', columnnames=classes_str, rownames=classes_str),)# 画出ROC曲线test_roc_arr = np.array(eval_dict['test_roc'])zeros_df = pd.DataFrame({'fpr': [0], 'tpr': [0]})  # 用于填充的(0,0)数据点ones_df = pd.DataFrame({'fpr': [1], 'tpr': [1]})  # 用于填充的(1,1)数据点if num_classes == 2:plot_arr = test_roc_arr[:, :, 1]  # 提取(fpr,tpr)数据点,形状为(len,(fpr,tpr))cats = pd.qcut(plot_arr[:, 0], q=10, labels=False, duplicates='drop')  # 按fpr大小分成10个数据一样多的区间groups = pd.DataFrame(plot_arr, columns=['fpr', 'tpr']).groupby(cats).mean()  # 计算每个区间的平均值,形状为(10,(fpr,tpr))plot_df = pd.concat([zeros_df, groups, ones_df])  # 头添加(0,0),尾添加(1,1)数据点,形状为(12,(fpr,tpr))x = plot_df['fpr']Y = plot_df['tpr']else:plot_df_list = []plot_arr = test_roc_arr[:, :, :, 1].swapaxes(1, 2)  # 提取(fpr,tpr)数据点并换轴,形状为(len,num_classes,(fpr,tpr))for i in range(num_classes):cats = pd.qcut(plot_arr[:, i, 0], q=10, labels=False, duplicates='drop')groups = pd.DataFrame(plot_arr[:, i, :], columns=['fpr', 'tpr']).groupby(cats).mean()  # 形状为(10,(fpr,tpr))plot_df = pd.concat([zeros_df, groups, ones_df])  # 形状为(12,(fpr,tpr))add_num = 12 - len(plot_df)# 长度不足12时,插值填充(0,0)数据点if add_num > 0:plot_df = pd.concat([zeros_df] * add_num + [plot_df])plot_df_list.append(plot_df)  # 形状为(num_classes,12,(fpr,tpr))plot_arr_sum = np.stack(plot_df_list, axis=1)  # 形状为(12,num_classes,(fpr,tpr))x = plot_arr_sum[:, :, 0]Y = plot_arr_sum[:, :, 1]vis.line(X=x,Y=Y,win='ROC',opts=dict(title='ROC', legend=classes_str),)return model

注意:代码运行前要先在命令行输入python -m visdom.server,在浏览器中打开提供的链接:

 成功运行的效果如下:

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Debezium系列之:记录一次命令行可以访问mysql数据库,但是debezium connector无法访问数据库原因排查
  • 掌控安全CTF-2024年8月擂台赛-ez_misc
  • XSS - LABS —— 靶场笔记合集
  • Sentinel组件详解:使用与原理
  • 进程间通信:采用有名管道,创建两个发送接收端,父进程写入管道1和管道2,子进程读取管道2和管道1.
  • Python中的赋值运算符:解锁编程的无限可能
  • 加速打开gtihub的工具dev-sidecar
  • 急急急!苹果手机突然黑屏无法开机怎么办?能解决吗?
  • PHPShort轻量级网址缩短程序源码开心版,内含汉化包
  • 微软Win11 22H2/23H2八月可选更新KB5041587发布!
  • Element-plus组件库基础组件使用
  • 如祺出行发布首份中期业绩,总收入增长13.6%
  • ShardingSphere学习笔记
  • Java—可变参数、不可变集合
  • Java面试宝典-java基础04
  • 【附node操作实例】redis简明入门系列—字符串类型
  • Computed property XXX was assigned to but it has no setter
  • egg(89)--egg之redis的发布和订阅
  • Java反射-动态类加载和重新加载
  • Magento 1.x 中文订单打印乱码
  • markdown编辑器简评
  • Otto开发初探——微服务依赖管理新利器
  • React中的“虫洞”——Context
  • weex踩坑之旅第一弹 ~ 搭建具有入口文件的weex脚手架
  • 爬虫模拟登陆 SegmentFault
  • 浅谈Kotlin实战篇之自定义View图片圆角简单应用(一)
  • 如何利用MongoDB打造TOP榜小程序
  • 支付宝花15年解决的这个问题,顶得上做出十个支付宝 ...
  • ​​​​​​​STM32通过SPI硬件读写W25Q64
  • ​一、什么是射频识别?二、射频识别系统组成及工作原理三、射频识别系统分类四、RFID与物联网​
  • (BAT向)Java岗常问高频面试汇总:MyBatis 微服务 Spring 分布式 MySQL等(1)
  • (C语言版)链表(三)——实现双向链表创建、删除、插入、释放内存等简单操作...
  • (delphi11最新学习资料) Object Pascal 学习笔记---第8章第5节(封闭类和Final方法)
  • (安卓)跳转应用市场APP详情页的方式
  • (二)原生js案例之数码时钟计时
  • (非本人原创)我们工作到底是为了什么?​——HP大中华区总裁孙振耀退休感言(r4笔记第60天)...
  • (牛客腾讯思维编程题)编码编码分组打印下标题目分析
  • (七)Flink Watermark
  • (企业 / 公司项目)前端使用pingyin-pro将汉字转成拼音
  • (强烈推荐)移动端音视频从零到上手(下)
  • (三)终结任务
  • (深度全面解析)ChatGPT的重大更新给创业者带来了哪些红利机会
  • (源码版)2024美国大学生数学建模E题财产保险的可持续模型详解思路+具体代码季节性时序预测SARIMA天气预测建模
  • (转)【Hibernate总结系列】使用举例
  • (转)memcache、redis缓存
  • (转)清华学霸演讲稿:永远不要说你已经尽力了
  • (转载)虚幻引擎3--【UnrealScript教程】章节一:20.location和rotation
  • *_zh_CN.properties 国际化资源文件 struts 防乱码等
  • .net FrameWork简介,数组,枚举
  • .NET 实现 NTFS 文件系统的硬链接 mklink /J(Junction)
  • .net 使用ajax控件后如何调用前端脚本
  • .NET/C# 在代码中测量代码执行耗时的建议(比较系统性能计数器和系统时间)
  • .net6 当连接用户的shell断掉后,dotnet会自动关闭,达不到长期运行的效果。.NET 进程守护
  • .net操作Excel出错解决
  • .NET导入Excel数据