当前位置: 首页 > news >正文

Tensorflow2.0笔记 - 自定义Layer和Model实现CIFAR10数据集的训练

       本笔记记录使用自定义Layer和Model来做CIFAR10数据集的训练。

        CIFAR10数据集下载:

        https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

        自定义的Layer和Model实现较为简单,参数量较少,并且没有卷积层和dropout等,最终准确率不高,仅做练习使用。

import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255y = tf.cast(y, dtype=tf.int32)return x,ybatchsize = 128
#CIFAR10数据集下载,可以直接使用网络下载
(x,y), (x_val, y_val) = datasets.cifar10.load_data()
#CIFAR10的标签(训练集)数据维度是[50000, 1],通过squeeze消除掉里面1的维度,变成[50000]
print("y.shape:", y.shape)
y = tf.squeeze(y)
print("squeezed y.shape:", y.shape)
y_val = tf.squeeze(y_val)
#进行onehot编码
y = tf.one_hot(y, depth=10)
y_val = tf.one_hot(y_val, depth=10)
print("Datasets: ", x.shape, " ", y.shape, " x.min():", x.min(), " x.max():", x.max())train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsize)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batchsize)sample = next(iter(train_db))
print("Batch:", sample[0].shape, sample[1].shape)#自定义Layer
class MyDense(layers.Layer):def __init__(self, input_dim, output_dim):super(MyDense, self).__init__()self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim], initializer=tf.random_uniform_initializer(0, 1.0))self.bias = self.add_weight(name='b', shape=[output_dim], initializer=tf.random_uniform_initializer(0, 1.0))#self.kernel = self.add_weight(name='w', shape=[input_dim, output_dim])#self.bias = self.add_weight(name='b', shape=[output_dim])def call(self, inputs, training = None):x = inputs@self.kernel + self.biasreturn xclass MyNetwork(keras.Model):def __init__(self):super(MyNetwork, self).__init__()self.fc1 = MyDense(32 * 32 * 3, 512)self.fc2 = MyDense(512, 512)self.fc3 = MyDense(512, 256)self.fc4 = MyDense(256, 256)self.fc5 = MyDense(256, 10)def call(self, inputs, training = None):x = tf.reshape(inputs, [-1, 32 * 32 * 3])x = self.fc1(x)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)x = tf.nn.relu(x)#返回logitsreturn xtotal_epoches = 35
learn_rate = 0.001
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=learn_rate),loss = tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['Accuracy'])
network.fit(train_db, epochs=total_epoches, validation_data=test_db, validation_freq=1)

运行结果:

相关文章:

  • TCP的十个重要的机制
  • [每周一更]-第92期:Go项目中的限流算法
  • 信创环境ES索引管理脚本:close, delete
  • 优化 Nginx 处理 504 Gateway Timeout 错误
  • 【漏洞复现】WordPress Plugin LearnDash LMS 敏感信息暴漏
  • 即刻体验 | 使用 Flutter 3.19 更高效地开发
  • 【软件工程】详细设计(一)
  • Autodesk AutoCAD 2025 (macOS, Windows) - 自动计算机辅助设计软件
  • 文件操作讲解
  • Golang基础-9
  • 后端前行Vue之路(三):计算属性和监视属性
  • YARN集群 和 MapReduce 原理及应用
  • Git 常用命令集
  • GitGithub小册:版本管理必备利器
  • 超文本传输协议HTTP
  • Java 23种设计模式 之单例模式 7种实现方式
  • Lucene解析 - 基本概念
  • miaov-React 最佳入门
  • Netty 4.1 源代码学习:线程模型
  • PhantomJS 安装
  • python3 使用 asyncio 代替线程
  • React-flux杂记
  • TypeScript实现数据结构(一)栈,队列,链表
  • Wamp集成环境 添加PHP的新版本
  • Web设计流程优化:网页效果图设计新思路
  • 关于字符编码你应该知道的事情
  • 机器学习 vs. 深度学习
  • 实战:基于Spring Boot快速开发RESTful风格API接口
  • 视频flv转mp4最快的几种方法(就是不用格式工厂)
  • 我建了一个叫Hello World的项目
  • ​软考-高级-信息系统项目管理师教程 第四版【第14章-项目沟通管理-思维导图】​
  • #HarmonyOS:Web组件的使用
  • #Z2294. 打印树的直径
  • #我与Java虚拟机的故事#连载08:书读百遍其义自见
  • (14)学习笔记:动手深度学习(Pytorch神经网络基础)
  • (145)光线追踪距离场柔和阴影
  • (4)通过调用hadoop的java api实现本地文件上传到hadoop文件系统上
  • (6)STL算法之转换
  • (多级缓存)缓存同步
  • (五)Python 垃圾回收机制
  • (转)负载均衡,回话保持,cookie
  • **PHP分步表单提交思路(分页表单提交)
  • .[hudsonL@cock.li].mkp勒索病毒数据怎么处理|数据解密恢复
  • .Net IE10 _doPostBack 未定义
  • .Net MVC + EF搭建学生管理系统
  • .Net 中的反射(动态创建类型实例) - Part.4(转自http://www.tracefact.net/CLR-and-Framework/Reflection-Part4.aspx)...
  • .NET导入Excel数据
  • .Net调用Java编写的WebServices返回值为Null的解决方法(SoapUI工具测试有返回值)
  • .net连接MySQL的方法
  • /dev下添加设备节点的方法步骤(通过device_create)
  • /etc/skel 目录作用
  • /proc/vmstat 详解
  • /var/lib/dpkg/lock 锁定问题
  • @EnableAsync和@Async开始异步任务支持
  • @html.ActionLink的几种参数格式