当前位置: 首页 > news >正文

昇思25天学习打卡营第4天 | 网络构建

在学习和实践MindSpore神经网络模型构建的过程中,我深刻理解了MindSpore中如何通过nn.Cell类来构建和管理复杂的神经网络模型。通过这次的实践,我对神经网络的基本构建和应用有了更加全面的认识,以下是我学习过程中所总结的几点心得:

一、神经网络模型的基本构建

在MindSpore中,神经网络模型由神经网络层和Tensor操作构成。nn.Cell类是构建所有网络的基类,也是网络的基本单元。通过继承nn.Cell类并实现其__init__和construct方法,我们可以构建出各种复杂的神经网络结构。__init__方法用于进行子Cell的实例化和状态管理,而construct方法用于实现具体的Tensor操作。

例如,在构建一个用于Mnist数据集分类的神经网络模型时,我们定义了一个Network类,继承了nn.Cell。在__init__方法中,我们实例化了Flatten和SequentialCell子类,并在SequentialCell中定义了多层全连接层和激活函数ReLU。最后在construct方法中,实现了数据从输入到输出的完整流动过程。

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 512, weight_init="normal", bias_init="zeros"),nn.ReLU(),nn.Dense(512, 10, weight_init="normal", bias_init="zeros"))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logits
二、模型实例化与结构查看

在完成模型构建后,我们可以通过实例化Network对象来查看其结构。通过print(model),我们可以清晰地看到模型中每一层的详细信息,包括输入和输出通道数、是否有偏置等。这对于我们理解模型的内部构造和调试非常有帮助。

model = Network()
print(model)

输出的模型结构如下:

Network<(flatten): Flatten<>(dense_relu_sequential): SequentialCell<(0): Dense<input_channels=784, output_channels=512, has_bias=True>(1): ReLU<>(2): Dense<input_channels=512, output_channels=512, has_bias=True>(3): ReLU<>(4): Dense<input_channels=512, output_channels=10, has_bias=True>>>
三、模型推理与输出

在实例化模型后,我们可以通过构造输入数据并直接调用模型来进行推理。需要注意的是,construct方法不可直接调用,我们可以直接通过模型实例来传入输入数据进行推理。通过这种方式,我们可以获得模型的输出Tensor。

X = ops.ones((1, 28, 28), mindspore.float32)
logits = model(X)

通过进一步使用nn.Softmax层,我们可以将输出的logits值转化为预测概率,并通过argmax方法获得预测的类别。

pred_probab = nn.Softmax(axis=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")
四、逐层解析与理解

为了更好地理解神经网络的工作原理,我们可以逐层解析模型中的每一层。例如,我们可以构造一个shape为(3, 28, 28)的随机数据,并依次通过每一个神经网络层来观察其效果。

通过对nn.Flatten、nn.Dense、nn.ReLU等层的逐层解析,我们可以清晰地看到数据在每一层的变化过程。这对于我们理解神经网络的内部工作原理非常有帮助。

input_image = ops.ones((3, 28, 28), mindspore.float32)
flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.shape)layer1 = nn.Dense(in_channels=28*28, out_channels=20)
hidden1 = layer1(flat_image)
print(hidden1.shape)hidden1 = nn.ReLU()(hidden1)
print(hidden1)
五、模型参数管理

神经网络中的每一层(如nn.Dense)都具有权重参数和偏置参数。这些参数会在训练过程中不断进行优化。我们可以通过model.parameters_and_names()来获取模型中所有参数的名称及其详细信息。

for name, param in model.parameters_and_names():print(f"Layer: {name}\nSize: {param.shape}\nValues : {param[:2]} \n")
总结

通过这次学习和实践,我掌握了在MindSpore中构建神经网络模型的基本方法和技巧。通过对模型逐层解析和参数管理的深入理解,我不仅提高了对神经网络内部工作原理的认识,也增强了实际操作的能力。这为我今后在深度学习领域的研究和应用打下了坚实的基础。
在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Fastgpt本地或服务器私有化部署常见问题
  • 设计App的后端接口分类以及环境依赖包详情
  • java启动shell脚本
  • python入门课程Pro(1)--数据结构及判断
  • spring-boot2.x整合Kafka步骤
  • 7.18 学习笔记 解决分页越界问题 及分页查询
  • (02)Unity使用在线AI大模型(调用Python)
  • junit mockito service
  • 【Linux知识点汇总】07 Linux系统防火墙相关命令,关闭和开启防火墙、开放端口号
  • leetcode-三数之和
  • 第一章 应急响应- Linux入侵排查
  • windows 11 PC查询连接过的wlan密码
  • 高通Android 12 设置Global属性为null问题
  • 数据库——单表查询
  • 机械臂泡水维修|机器人雨后进水维修措施
  • 【347天】每日项目总结系列085(2018.01.18)
  • CSS中外联样式表代表的含义
  • Dubbo 整合 Pinpoint 做分布式服务请求跟踪
  • emacs初体验
  • Eureka 2.0 开源流产,真的对你影响很大吗?
  • IIS 10 PHP CGI 设置 PHP_INI_SCAN_DIR
  • java中的hashCode
  • mysql常用命令汇总
  • MySQL几个简单SQL的优化
  • SOFAMosn配置模型
  • sublime配置文件
  • Vue ES6 Jade Scss Webpack Gulp
  • vue-cli3搭建项目
  • 对象管理器(defineProperty)学习笔记
  • 分类模型——Logistics Regression
  • 排序算法学习笔记
  • 问:在指定的JSON数据中(最外层是数组)根据指定条件拿到匹配到的结果
  • const的用法,特别是用在函数前面与后面的区别
  • LIGO、Virgo第三轮探测告捷,同时探测到一对黑洞合并产生的引力波事件 ...
  • ​queue --- 一个同步的队列类​
  • ​软考-高级-系统架构设计师教程(清华第2版)【第15章 面向服务架构设计理论与实践(P527~554)-思维导图】​
  • #C++ 智能指针 std::unique_ptr 、std::shared_ptr 和 std::weak_ptr
  • #define 用法
  • (C++哈希表01)
  • (LeetCode) T14. Longest Common Prefix
  • (动手学习深度学习)第13章 计算机视觉---微调
  • (附源码)ssm教师工作量核算统计系统 毕业设计 162307
  • (更新)A股上市公司华证ESG评级得分稳健性校验ESG得分年均值中位数(2009-2023年.12)
  • (转)es进行聚合操作时提示Fielddata is disabled on text fields by default
  • (转)ORM
  • . Flume面试题
  • .gitignore文件忽略的内容不生效问题解决
  • .naturalWidth 和naturalHeight属性,
  • .net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法
  • .Net mvc总结
  • .net 程序发生了一个不可捕获的异常
  • .NET 设计模式—简单工厂(Simple Factory Pattern)
  • .net 使用$.ajax实现从前台调用后台方法(包含静态方法和非静态方法调用)
  • .net 写了一个支持重试、熔断和超时策略的 HttpClient 实例池
  • .Net 应用中使用dot trace进行性能诊断