当前位置: 首页 > news >正文

基于安卓的虫害识别软件设计--(1)模型训练与可视化

引言

  • 简介:使用pytorch框架,从模型训练、模型部署完整地实现了一个基础的图像识别项目
  • 计算资源:使用的是Kaggle(每周免费30h的GPU)

1.创建名为“utils_1”的模块

模块中包含:训练和验证的加载器函数训练函数验证函数

import os
import sysimport torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdmdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def get_train_loader(image_path):train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform = train_transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,shuffle=True, num_workers= 0)return train_loaderdef get_val_loader(image_path):val_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_dataset = datasets.ImageFolder(root=os.path.join(image_path, "validation"),transform = val_transform)val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=32,shuffle = False, num_workers = 0)return val_loaderdef train(train_loader,net):net.train()train_correct = 0.0train_loss = 0.0  # 初始化训练损失train_bar = tqdm(train_loader, file=sys.stdout)loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)optimizer = optim.Adam(net.parameters(), lr=0.001)for step, data in enumerate(train_bar):images, labels = dataimages, labels = images.to(device),labels.to(device)# 梯度清零optimizer.zero_grad()# 训练outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 反向传播loss.backward()# 更新权重optimizer.step()# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()train_correct += correcttrain_loss += loss.item()  # 累加损失值train_bar.desc = 'Training Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * train_loader.batch_size + len(images),total_samples=len(train_loader.dataset))train_correct = (100. * train_correct) / len(train_loader.dataset)train_loss /= len(train_loader)  # 计算平均损失值return train_correct, train_loss  # 返回训练正确率和平均损失值def val(val_loader,net):net.eval()val_correct = 0.0val_loss = 0.0  # 初始化验证损失loss_function = nn.CrossEntropyLoss()loss_function = loss_function.to(device)val_bar = tqdm(val_loader, file=sys.stdout)for step, data in enumerate(val_bar):images, labels = dataimages, labels = images.to(device), labels.to(device)with torch.no_grad():# 验证outputs = net(images)# 计算损失loss = loss_function(outputs, labels)# 统计_, preds = outputs.max(1)correct = preds.eq(labels).sum()val_correct += correctval_loss += loss.item()  # 累加损失值val_bar.desc = 'Valing Epoch:[{trained_samples}/{total_samples}]\t Loss: {:0.4f}\t Accuracy: {:0.4f}\t'.format(loss.item(),(100. * correct) / len(outputs),trained_samples=step * val_loader.batch_size + len(images),total_samples=len(val_loader.dataset))val_correct = (100. * val_correct) / len(val_loader.dataset)val_loss /= len(val_loader)  # 计算平均损失值return val_correct , val_loss  # 返回验证正确率和平均损失值

注意:若使用Kaggle,想要导入该模块,需要添加以下代码

import sys
sys.path.append(r'/kaggle/input/mycode2')

其中,模块路径如下图


2.主函数 

主函数包含:使用模型函数训练主函数画图代码

2.1使用模型函数 

【若使用其他模型,可chatgpt创建其函数】

(1)resnet101 

def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net

(2)resnet34 

def get_resnet34(class_num):net_name = "resnet34"net = torchvision.models.resnet34(pretrained=True)net.fc = Linear(in_features=512, out_features=class_num, bias=True)net = net.to(device)return net_name,net

(3)mobilenetv2

def get_mobilenet_v2(class_num):net_name = "mobilenet_v2"net = torchvision.models.mobilenet_v2(pretrained=True)net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)net = net.to(device)return net_name,net

 2.2画图代码 

    save_path="/kaggle/working/"  plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.3完整代码 

import torch
import torchvision.models
from matplotlib import pyplot as plt
from torch.nn import Linear
import os# 导入自己创建的模块
from utils_1 import get_train_loader, train, val, get_val_loader# 模型选择
def get_resnet101(class_num):net_name = "resnet101"net = torchvision.models.resnet101(pretrained=True)net.fc = Linear(in_features=2048, out_features=class_num, bias=True)  # ResNet101's fully connected layer expects 2048 input featuresnet = net.to(device)return net_name, net# def get_resnet34(class_num):
#     net_name = "resnet34"
#     net = torchvision.models.resnet34(pretrained=True)
#     net.fc = Linear(in_features=512, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# def get_mobilenet_v2(class_num):
#     net_name = "mobilenet_v2"
#     net = torchvision.models.mobilenet_v2(pretrained=True)
#     net.classifier[1] = Linear(in_features=1280, out_features=class_num, bias=True)
#     net = net.to(device)
#     return net_name,net# 训练主函数
if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#1 加载数据image_path = r"/kaggle/input/fruits3"train_loader = get_train_loader(image_path)val_loader = get_val_loader(image_path)#2 加载模型net_name,net = get_resnet34(class_num=5)#3 训练epochs = 5best_acc = 0train_losses = []val_losses = []train_accs = []val_accs = []for epoch in range(epochs):train_acc,train_loss = train(train_loader, net)val_acc,val_loss = val(val_loader, net)train_losses.append(train_loss)val_losses.append(val_loss)train_accs.append(train_acc.item())val_accs.append(val_acc.item())if best_acc<val_acc:best_acc = val_acctorch.save(net, os.path.join("/kaggle/working/", net_name + ".pt"))# 画图save_path="/kaggle/working/" # 图片保存路径plt.figure(figsize=(12, 4))# lossplt.subplot(1, 2, 1)plt.plot(range(1, epochs + 1), train_losses, "r-",label='Train loss')plt.plot(range(1, epochs + 1), val_losses, "b-",label='Val loss')plt.legend()plt.xlabel('Epoch')plt.ylabel('Loss')# accplt.subplot(1, 2, 2)plt.plot(range(1, epochs + 1), train_accs,"r-", label='Train acc')plt.plot(range(1, epochs + 1), val_accs,"b-" ,label='Val acc')plt.legend()plt.xlabel('Epoch')plt.ylabel('Acc')plt.legend()plt.savefig(os.path.join(save_path, 'result.png')) # 保存plt.show()

2.4训练效果与模型文件

相关文章:

  • 关于安装typescript后运行tsc -v命令报错问题
  • idm2024最新完美破解版免费下载 idm绿色直装版注册机免费分享 idm永久激活码工具
  • HTML5 视频 Vedio 标签详解
  • 神经网络---网络模型的保存、加载
  • 分治算法例子
  • OceanBase v4.2 解读:tenant=all 语义优化,提升易用性
  • Java Web学习笔记4——HTML、CSS
  • PyTorch 的 torch.nn 模块学习
  • 正则表达式----IP地址合法性判断
  • 啵啵啵啵啵啵啵啵啵啵啵啵啵啵啵
  • Java面试——中间件
  • 嵌入式Linux系统编程 — 2.1 标准I/O库简介
  • cs与msf权限传递
  • 最大矩形问题
  • 如何给 MySQL 表和列授予权限?(官方版)
  • -------------------- 第二讲-------- 第一节------在此给出链表的基本操作
  • IOS评论框不贴底(ios12新bug)
  • Promise面试题,控制异步流程
  • SQLServer插入数据
  • SSH 免密登录
  • Sublime text 3 3103 注册码
  • vue--为什么data属性必须是一个函数
  • Webpack入门之遇到的那些坑,系列示例Demo
  • 基于Mobx的多页面小程序的全局共享状态管理实践
  • 使用docker-compose进行多节点部署
  • 探索 JS 中的模块化
  • 用jQuery怎么做到前后端分离
  • 在 Chrome DevTools 中调试 JavaScript 入门
  • ​软考-高级-系统架构设计师教程(清华第2版)【第12章 信息系统架构设计理论与实践(P420~465)-思维导图】​
  • ​学习一下,什么是预包装食品?​
  • # 手柄编程_北通阿修罗3动手评:一款兼具功能、操控性的电竞手柄
  • #数据结构 笔记三
  • $NOIp2018$劝退记
  • (12)目标检测_SSD基于pytorch搭建代码
  • (二十三)Flask之高频面试点
  • (分享)自己整理的一些简单awk实用语句
  • (全部习题答案)研究生英语读写教程基础级教师用书PDF|| 研究生英语读写教程提高级教师用书PDF
  • .a文件和.so文件
  • .dat文件写入byte类型数组_用Python从Abaqus导出txt、dat数据
  • .net core 控制台应用程序读取配置文件app.config
  • .NET 设计模式—简单工厂(Simple Factory Pattern)
  • .NET应用架构设计:原则、模式与实践 目录预览
  • /proc/interrupts 和 /proc/stat 查看中断的情况
  • /proc/vmstat 详解
  • @converter 只能用mysql吗_python-MySQLConverter对象没有mysql-connector属性’...
  • @ModelAttribute 注解
  • [Android Pro] AndroidX重构和映射
  • [Android]RecyclerView添加HeaderView出现宽度问题
  • [c#基础]DataTable的Select方法
  • [CVPR 2023:3D Gaussian Splatting:实时的神经场渲染]
  • [Django学习]查询过滤器(lookup types)
  • [Enterprise Library]调用Enterprise Library时出现的错误事件之关闭办法
  • [error] 17755#0: *58522 readv() failed (104: Connection reset by peer) while reading upstream
  • [HTML]Web前端开发技术29(HTML5、CSS3、JavaScript )JavaScript基础——喵喵画网页
  • [JS]JavaScript 简介