当前位置: 首页 > news >正文

pytorch神经网络训练(AlexNet)

  • 导包
import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms
  • 定义自定义图像数据集
class CustomImageDataset(Dataset): 

定义一个自定义的图像数据集类,继承自Dataset

def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

        self.transform = transform

 图像转换方法,用于对图像进行预处理

        self.files = [] 

存储所有图像文件的路径

        self.labels = [] 

存储所有图像的标签

        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

        for index, label in enumerate(os.listdir(main_dir)):

 遍历主目录中的所有子目录

 

          self.label_to_index[label] = index label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

           if os.path.isdir(label_dir): for file in os.listdir(label_dir): self.files.append(os.path.join(label_dir, file))self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

def __len__(self):

定义数据集的长度

        return len(self.files) 

返回文件列表的长度

def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

        image = Image.open(self.files[idx]) label = self.labels[idx] if self.transform: image = self.transform(image) return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换
transform = transforms.Compose([transforms.Resize((227, 227)),  # AlexNet的输入图像大小transforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.RandomRotation(10),  # 随机旋转transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化])

  • 创建数据集
dataset = CustomImageDataset(main_dir="D:\\图像处理、深度学习\\flowers", transform=transform)
  • 创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 加载预训练的AlexNet模型
alexnet_model = models.alexnet(pretrained=True)
  • 修改最后几层以适应新的分类任务
num_ftrs = alexnet_model.classifier[6].in_featuresalexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))
  • 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)
  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")alexnet_model.to(device)                                                               

  • 模型评估
def evaluate_model(model, data_loader, device):model.eval()  # 将模型设置为评估模式correct = 0total = 0with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度for images, labels in data_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalreturn accuracy
  • 训练模型
num_epochs = 10for epoch in range(num_epochs):alexnet_model.train()running_loss = 0.0for images, labels in data_loader:images, labels = images.to(device), labels.to(device)

前向传播

        outputs = alexnet_model(images)loss = criterion(outputs, labels)

反向传播和优化

        optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item()

在每个epoch结束后评估模型

    train_accuracy = evaluate_model(alexnet_model, data_loader, device)print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

相关文章:

  • 长难句打卡6.14
  • for 、while循环
  • Git代码冲突原理与三路合并算法
  • Android Studio新增功能:Device Streaming
  • 基于redis的分布式锁
  • 开源WebGIS全流程常用技术栈
  • Log4j日志级别介绍
  • 2024.06.01 校招 实习 内推 面经
  • Spring Boot 的启动原理、Spring Boot 自动配置原理
  • C++面向对象程序设计 - 命名空间
  • stm32编写Modbus步骤
  • Idea jdk配置的地方 启动时指定切换的地方
  • 嵌入式学习
  • 学习分享-分布式 NoSQL 数据库管理系统Cassandra以及它和redis的区别
  • 2024 Java 异常—面试常见问题
  • $translatePartialLoader加载失败及解决方式
  • 【mysql】环境安装、服务启动、密码设置
  • LeetCode刷题——29. Divide Two Integers(Part 1靠自己)
  • Redis在Web项目中的应用与实践
  • Vim 折腾记
  • Vue UI框架库开发介绍
  • 从 Android Sample ApiDemos 中学习 android.animation API 的用法
  • 第2章 网络文档
  • 复杂数据处理
  • 使用agvtool更改app version/build
  • 吐槽Javascript系列二:数组中的splice和slice方法
  • 项目管理碎碎念系列之一:干系人管理
  • 找一份好的前端工作,起点很重要
  • MiKTeX could not find the script engine ‘perl.exe‘ which is required to execute ‘latexmk‘.
  • ​Base64转换成图片,android studio build乱码,找不到okio.ByteString接腾讯人脸识别
  • ​马来语翻译中文去哪比较好?
  • # SpringBoot 如何让指定的Bean先加载
  • #每日一题合集#牛客JZ23-JZ33
  • (CPU/GPU)粒子继承贴图颜色发射
  • (pytorch进阶之路)扩散概率模型
  • (附源码)springboot家庭财务分析系统 毕业设计641323
  • (论文阅读31/100)Stacked hourglass networks for human pose estimation
  • (转)ORM
  • (转)程序员技术练级攻略
  • **PHP分步表单提交思路(分页表单提交)
  • .NET Core日志内容详解,详解不同日志级别的区别和有关日志记录的实用工具和第三方库详解与示例
  • .NET/C# 编译期能确定的字符串会在字符串暂存池中不会被 GC 垃圾回收掉
  • .net6 webapi log4net完整配置使用流程
  • .NET中GET与SET的用法
  • .net中生成excel后调整宽度
  • /tmp目录下出现system-private文件夹解决方法
  • [1181]linux两台服务器之间传输文件和文件夹
  • [2019.3.20]BZOJ4573 [Zjoi2016]大森林
  • [51nod1610]路径计数
  • [HDU5685]Problem A
  • [IE编程] 多页面基于IE内核浏览器的代码示例
  • [Java][Android][Process] 暴力的服务能够解决一切,暴力的方式运行命令行语句
  • [LaTex]arXiv投稿攻略——jpg/png转pdf
  • [MySQL]视图索引以及连接查询案列
  • [NLP] LlaMa2模型运行在Mac机器