当前位置: 首页 > news >正文

昇思第7天

模型训练

模型训练一般分为四个步骤:

构建数据集。
定义神经网络模型。
定义超参、损失函数及优化器。
输入数据集进行训练与评估。

  1. 数据集加载
import mindspore
from mindspore import nn
# 从 MindSpore 数据集包中导入 vision 和 transforms 模块。
# vision:包含处理图像数据的工具。
# transforms:包含数据转换的工具。
from mindspore.dataset import vision, transforms
# 从 MindSpore 数据集包中导入 MnistDataset 类,用于加载 MNIST 数据集。
from mindspore.dataset import MnistDataset
# 从 download 模块中导入 download 函数,用于下载数据集。
from download import download# 指定数据集的 URL 地址。
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"# 使用 download 函数下载数据集并解压到当前目录。
path = download(url, "./", kind="zip", replace=True)# 定义一个数据管道函数,接收数据集路径和批量大小作为参数。
def datapipe(path, batch_size):# 定义图像数据的转换操作列表。image_transforms = [vision.Rescale(1.0 / 255.0, 0),       # 缩放图像像素值到 [0, 1] 范围。vision.Normalize(mean=(0.1307,), std=(0.3081,)),  # 标准化图像数据。vision.HWC2CHW()                      # 转换图像格式从 HWC(高度、宽度、通道)到 CHW(通道、高度、宽度)。]# 定义标签数据的转换操作,将标签转换为 int32 类型。label_transform = transforms.TypeCast(mindspore.int32)# 加载指定路径的数据集。dataset = MnistDataset(path)# 对数据集的图像应用转换操作。dataset = dataset.map(image_transforms, 'image')# 对数据集的标签应用转换操作。dataset = dataset.map(label_transform, 'label')# 将数据集分批,每批包含指定数量的样本。dataset = dataset.batch(batch_size)# 返回处理后的数据集。return dataset# 创建训练数据集,批量大小为 64。
train_dataset = datapipe('MNIST_Data/train', batch_size=64)# 创建测试数据集,批量大小为 64。
test_dataset = datapipe('MNIST_Data/test', batch_size=64)
  1. 构建神经网络
 # 定义一个神经网络类 Network,继承自 nn.Cell。
class Network(nn.Cell):# 在初始化方法中定义网络的结构。def __init__(self):# 调用父类的初始化方法。super().__init__()# 定义一个平坦化层,用于将输入的多维数据展开为一维。self.flatten = nn.Flatten()# 定义一个顺序容器 SequentialCell,其中包含多个层顺序连接。self.dense_relu_sequential = nn.SequentialCell(# 全连接层,将输入数据的尺寸从 28*28(即 784)转换为 512。nn.Dense(28*28, 512),# ReLU 激活函数。nn.ReLU(),# 全连接层,将输入数据的尺寸从 512 转换为 512。nn.Dense(512, 512),# ReLU 激活函数。nn.ReLU(),# 全连接层,将输入数据的尺寸从 512 转换为 10(对应于 10 个类别)。nn.Dense(512, 10))# 定义前向传播方法,用于计算网络的输出。def construct(self, x):# 将输入数据平坦化。x = self.flatten(x)# 依次通过顺序容器中的各层,得到最终的输出 logits。logits = self.dense_relu_sequential(x)# 返回计算得到的 logits。return logits# 创建一个 Network 类的实例,表示定义好的神经网络模型。
model = Network()

3.定义超参、损失函数及优化器。

# 定义训练的参数。
# 训练的轮数,即数据集将被遍历的次数。
epochs = 3
# 每个批次的大小,即一次训练中使用的样本数。
batch_size = 64
# 学习率,即模型参数在每次更新时调整的幅度。
learning_rate = 1e-2
# 定义训练的参数。
# 训练的轮数,即数据集将被遍历的次数。
epochs = 3
# 每个批次的大小,即一次训练中使用的样本数。
batch_size = 64
# 学习率,即模型参数在每次更新时调整的幅度。
learning_rate = 1e-2# 定义损失函数,用于计算预测结果与实际标签之间的差异。
# 使用交叉熵损失函数(CrossEntropyLoss),这是分类问题中常用的损失函数。
loss_fn = nn.CrossEntropyLoss()# 定义优化器,用于更新模型的参数。# 使用随机梯度下降(SGD)优化器。
# model.trainable_params() 获取模型中所有需要训练的参数。
# learning_rate 指定优化器的学习率。
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

4.训练与评估
训练

# 定义前向函数,用于计算模型输出和损失。
def forward_fn(data, label):# 使用模型计算预测值(logits)。logits = model(data)# 计算预测值与真实标签之间的损失。loss = loss_fn(logits, label)# 返回损失值和预测值。return loss, logits# 获取梯度函数,用于计算损失相对于模型参数的梯度。
# mindspore.value_and_grad 会计算前向函数的值和梯度。
# forward_fn: 计算损失的前向函数。
# None: 不需要计算的额外输出。
# optimizer.parameters: 需要计算梯度的参数。
# has_aux=True: 表示前向函数返回多个值。
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 定义单步训练函数。
def train_step(data, label):# 计算损失和梯度。(loss, _), grads = grad_fn(data, label)# 使用优化器更新模型参数。optimizer(grads)# 返回当前步的损失值。return loss# 定义训练循环函数。
def train_loop(model, dataset):# 获取数据集的大小(即批次的数量)。size = dataset.get_dataset_size()# 设置模型为训练模式。model.set_train()# 枚举数据集的每个批次。for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):# 执行单步训练,获取当前批次的损失值。loss = train_step(data, label)# 每 100 个批次打印一次损失值和当前批次编号。if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

测试函数

# 定义测试循环函数,用于在测试集上评估模型的性能。
def test_loop(model, dataset, loss_fn):# 获取数据集的批次数量。num_batches = dataset.get_dataset_size()# 设置模型为评估模式。model.set_train(False)# 初始化总样本数、测试损失和正确预测数。total, test_loss, correct = 0, 0, 0# 枚举数据集的每个批次。for data, label in dataset.create_tuple_iterator():# 使用模型进行预测。pred = model(data)# 累加总样本数。total += len(data)# 累加测试损失。test_loss += loss_fn(pred, label).asnumpy()# 累加正确预测数。correct += (pred.argmax(1) == label).asnumpy().sum()# 计算平均损失。test_loss /= num_batches# 计算准确率。correct /= total# 打印测试结果,包括准确率和平均损失。print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

运行

# 定义损失函数和优化器。
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)# 执行多个 epoch 的训练循环。
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")# 执行训练循环。train_loop(model, train_dataset)# 在测试集上进行评估。test_loop(model, test_dataset, loss_fn)print("Done!")

在这里插入图片描述

相关文章:

  • 递归算法练习
  • Qt的信号与槽机制底层原理
  • 核方法总结(三)———核主成分(kernel PCA)学习笔记
  • 【Python】字典练习
  • 深入了解 Redis 五种类型命令与如何在 Java 中操作 Redis
  • 冒泡排序写法
  • javaEE——Servlet
  • 探索哈希函数:数据完整性的守护者
  • 线性代数笔记
  • 软考系统架构师高效备考方法论
  • Python从零学习笔记(1)
  • 接口测试流程及测试点!
  • JS数据处理(冒泡寻找对象里面有个Key相同的值并处理相关数据)
  • slot插槽详解及动态插槽的使用
  • 全网最详细,零基础学会AI绘画Stable Diffusion,学不会来打我!
  • 【162天】黑马程序员27天视频学习笔记【Day02-上】
  • const let
  • CSS实用技巧干货
  • JavaScript新鲜事·第5期
  • JS+CSS实现数字滚动
  • node学习系列之简单文件上传
  • PAT A1050
  • PAT A1120
  • puppeteer stop redirect 的正确姿势及 net::ERR_FAILED 的解决
  • Shadow DOM 内部构造及如何构建独立组件
  • Windows Containers 大冒险: 容器网络
  • 从tcpdump抓包看TCP/IP协议
  • 复习Javascript专题(四):js中的深浅拷贝
  • 利用阿里云 OSS 搭建私有 Docker 仓库
  • 浏览器缓存机制分析
  • 强力优化Rancher k8s中国区的使用体验
  • 区块链分支循环
  • 使用Maven插件构建SpringBoot项目,生成Docker镜像push到DockerHub上
  • 我从编程教室毕业
  • 学习笔记DL002:AI、机器学习、表示学习、深度学习,第一次大衰退
  • 用quicker-worker.js轻松跑一个大数据遍历
  • # Python csv、xlsx、json、二进制(MP3) 文件读写基本使用
  • #我与Java虚拟机的故事#连载09:面试大厂逃不过的JVM
  • (1)(1.13) SiK无线电高级配置(六)
  • (CVPRW,2024)可学习的提示:遥感领域小样本语义分割
  • (C语言)共用体union的用法举例
  • (ResultSet.TYPE_SCROLL_INSENSITIVE,ResultSet.CONCUR_READ_ONLY)讲解
  • (编译到47%失败)to be deleted
  • (读书笔记)Javascript高级程序设计---ECMAScript基础
  • (二十五)admin-boot项目之集成消息队列Rabbitmq
  • (六)c52学习之旅-独立按键
  • (每日持续更新)jdk api之StringBufferInputStream基础、应用、实战
  • (七)理解angular中的module和injector,即依赖注入
  • (切换多语言)vantUI+vue-i18n进行国际化配置及新增没有的语言包
  • (全注解开发)学习Spring-MVC的第三天
  • (一)Linux+Windows下安装ffmpeg
  • (原)Matlab的svmtrain和svmclassify
  • (转) ns2/nam与nam实现相关的文件
  • ***php进行支付宝开发中return_url和notify_url的区别分析
  • .net 4.0发布后不能正常显示图片问题