当前位置: 首页 > news >正文

pytorch实现水果2分类(蓝莓,苹果)

1.数据集的路径,结构

dataset.py

目的:

        输入:没有输入,路径是写死了的。

        输出:返回的是一个对象,里面有self.data。self.data是一个列表,里面是(图片路径.jpg,标签)

        -data[item]返回的是(img_tensor , one-hot编码)。one-hot编码是[0,1]或者[1,0]

import glob
import os.pathimport cv2
import torch
from torch.utils.data import Dataset
from torchvision import transformsclass DtataAndLabel(Dataset):def __init__(self,path='fruits',is_train=True):self.tran=transforms.Compose([transforms.ToTensor(),transforms.Resize(size=(88,88))])is_train='train' if True else 'test'self.data=[]path=os.path.join(path,is_train)print('path=',path)print(os.path.join(path, '*', '*'))img_paths=glob.glob(os.path.join(path,'*','*'))for img_path in img_paths:label=0 if img_path.split('\\')[-2]=='blueberry' else 1self.data.append((img_path,label))def __getitem__(self, idx):#每一张图片返回一个img_tensor,one_hotimg_path,label =self.data[idx]img=cv2.imread(img_path)# img_gray=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)img_tensor=self.tran(img)img_tensor=img_tensor/255img_tensor=torch.flatten(img_tensor)one_hot=torch.zeros(2)one_hot[label]=1return img_tensor,one_hotdef __len__(self):return len(self.data)if __name__ == '__main__':# 测试data=DtataAndLabel()print(data[1][0].shape)print(data[1][1])

net.py

目的:将输入维度(k(k是加载进去的图片数),88,88,3)三通道的宽高是88,88,通过网络变化为(k,2)。

import torch.nn
import torch.nn as nnclass Net(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Linear(88*88*3, 800),nn.ReLU(),nn.Linear(800, 500),nn.ReLU(),nn.Linear(500, 800),nn.ReLU(),nn.Linear(800, 200),nn.ReLU(),nn.Linear(200, 2),)self.softmax=nn.Softmax(dim=1)def forward(self,x):x=self.model(x)x=self.softmax(x)return x
if __name__ == '__main__':net=Net()#测试一下x=torch.randn(1,100*100)out=net(x)print(out.shape)

test_train.py

目的:将图像丢进模型,然后训练出最优模型

步骤:

       1.定义初始化

                -定义拿到data对象

                -定义加载器分批加载,这里可以变换维度

                -定义初始化网络

                -定义损失函数,这里采用了均方差函数

                -定义优化器

        2.实现训练

                -将每一批数据丢给网络,此时维度发生了变化,产生了升维

                -使用优化器        

                        ---自动梯度清0

                        ---自动求导更新参数

                -计算损失值和准确度

        ·~自己建一个文件夹

import torch.optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdmfrom net import Net
from dataset import DtataAndLabel
import torch.nn as nn
class TrainAndTest():def __init__(self):self.writer = SummaryWriter("logs")self.train_data=DtataAndLabel(is_train=True)self.test_data=DtataAndLabel(is_train=False)#使用加载器分批加载self.train_loader=DataLoader(self.train_data,batch_size=10,shuffle=True)self.test_loader=DataLoader(self.test_data,batch_size=10,shuffle=True)#初始化网络#损失函数#优化器net=Net()self.net=netself.loss=nn.MSELoss()self.opt=torch.optim.Adam(net.parameters(),lr=0.001)self.min_loss=100.0self.weight_path='weight/best.pt'def train(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.train_loader, desc="train...", total=len(self.train_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)self.opt.zero_grad()loss.backward()self.opt.step()sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.train_loader)avg_acc = sum_acc / len(self.train_loader)print(f'train:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"train_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"train_avg_acc": avg_acc}, epoch)def test(self,epoch):sum_loss = 0sum_acc = 0for img_tensors, targets in tqdm(self.test_loader, desc="test...", total=len(self.test_loader)):out = self.net(img_tensors)loss = self.loss(out, targets)sum_loss += loss.item()pred_cls = torch.argmax(out, dim=1)target_cls = torch.argmax(targets, dim=1)accuracy = torch.mean(torch.eq(pred_cls, target_cls).to(torch.float32))sum_acc += accuracy.item()avg_loss = sum_loss / len(self.test_loader)avg_acc = sum_acc / len(self.test_loader)print(f'test:loss{round(avg_loss, 3)} acc:{round(avg_acc, 3)}')self.writer.add_scalars("loss", {"test_avg_loss": avg_loss}, epoch)self.writer.add_scalars("acc", {"test_avg_acc": avg_acc}, epoch)if avg_loss<self.min_loss:self.min_loss=min(self.min_loss,avg_loss)torch.save(self.net.state_dict(), self.weight_path)def run(self):for epo in range(100):self.train(epo)self.test(epo)if __name__ == '__main__':trainer=TrainAndTest()trainer.run()

精度的计算:

                比如通过网络出现的维度是(1,2),其数值是[[0.9 , 0.1]](0.9与0.1表示预测的两个类别的概率)。我们通过maxarg取到其中最大的索引0,与之前真实的标签0或者1做比较。从而可以得出结果

 

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • FastApi+WebSocket 解析
  • 【数据结构】双向链表
  • DDR3 (四)
  • WPF引入多个控件库使用
  • 机器学习——LR、‌GBDT、‌SVM、‌CNN、‌DNN、‌RNN、‌Word2Vec等模型的原理和应用
  • Kodcloud可道云安装与一键发布上线实现远程访问详细教程
  • Linux操作系统CentOS如何更换yum镜像源
  • 【智能制造-14】机器视觉软件
  • 2.5 C#视觉程序开发实例1----CamManager实现模拟相机采集图片
  • python杨辉三角的两种书写方式
  • LLM代理应用实战:构建Plotly数据可视化代理
  • 【区块链农场】:农场游戏+游戏
  • Unity之OpenXR+XR Interaction Toolkit实现 Gaze眼部追踪
  • 使用node-cmd重启electron
  • 常见的开源工具(代码托管平台)都有哪些
  • 【css3】浏览器内核及其兼容性
  • crontab执行失败的多种原因
  • Java知识点总结(JDBC-连接步骤及CRUD)
  • puppeteer stop redirect 的正确姿势及 net::ERR_FAILED 的解决
  • RxJS: 简单入门
  • 程序员该如何有效的找工作?
  • 初识 beanstalkd
  • 离散点最小(凸)包围边界查找
  • 理解 C# 泛型接口中的协变与逆变(抗变)
  • 浅谈Golang中select的用法
  • 如何实现 font-size 的响应式
  • ​​​​​​​​​​​​​​Γ函数
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • ​必胜客礼品卡回收多少钱,回收平台哪家好
  • # Swust 12th acm 邀请赛# [ A ] A+B problem [题解]
  • #include到底该写在哪
  • #我与Java虚拟机的故事#连载15:完整阅读的第一本技术书籍
  • $$$$GB2312-80区位编码表$$$$
  • (1)Hilt的基本概念和使用
  • (Java岗)秋招打卡!一本学历拿下美团、阿里、快手、米哈游offer
  • (二)springcloud实战之config配置中心
  • (附程序)AD采集中的10种经典软件滤波程序优缺点分析
  • (附源码)python旅游推荐系统 毕业设计 250623
  • (汇总)os模块以及shutil模块对文件的操作
  • (十六)一篇文章学会Java的常用API
  • (十一)手动添加用户和文件的特殊权限
  • (实测可用)(3)Git的使用——RT Thread Stdio添加的软件包,github与gitee冲突造成无法上传文件到gitee
  • (使用vite搭建vue3项目(vite + vue3 + vue router + pinia + element plus))
  • (学习日记)2024.04.04:UCOSIII第三十二节:计数信号量实验
  • (转) 深度模型优化性能 调参
  • (转)Sql Server 保留几位小数的两种做法
  • **登录+JWT+异常处理+拦截器+ThreadLocal-开发思想与代码实现**
  • .net core docker部署教程和细节问题
  • .NET core 自定义过滤器 Filter 实现webapi RestFul 统一接口数据返回格式
  • .net core使用ef 6
  • /etc/fstab 只读无法修改的解决办法
  • [ NOI 2001 ] 食物链
  • [000-002-01].数据库调优相关学习
  • [100天算法】-x 的平方根(day 61)
  • [2544]最短路 (两种算法)(HDU)