当前位置: 首页 > news >正文

LeNet实验 四分类 与 四分类变为多个二分类

 

目录

1. 划分二分类

2. 训练独立的二分类模型

3. 二分类预测结果代码

4. 二分类预测结果

5 改进训练模型

6 优化后 预测结果代码

 7 优化后预测结果

8 训练四分类模型 

9 预测结果代码

10 四分类结果识别


1. 划分二分类

可以根据不同的类别进行多个划分,以实现NonDemented为例,划分为NonDemented和Demented两类,不属于NonDemented的全都属于Demented

2. 训练独立的二分类模型

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGeneratorfrom 文件准备 import data_dir# 数据生成器
train_datagen = ImageDataGenerator(rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,validation_split=0.2  # 20%用于验证
)train_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='training'
)validation_generator = train_datagen.flow_from_directory(data_dir,target_size=(28, 28),batch_size=32,class_mode='binary',subset='validation'
)# 构建LeNet-5模型
model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 3), padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu', padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(120, (5, 5), activation='relu', padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))# 编译模型
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_generator,steps_per_epoch=train_generator.samples // train_generator.batch_size,epochs=10,validation_data=validation_generator,validation_steps=validation_generator.samples // validation_generator.batch_size
)# 保存模型
model.save('lenet_binary_classification_model.h5')

3. 预测结果代码

import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = tf.keras.models.load_model('lenet_binary_classification_model.h5')# 预处理图像
def preprocess_image(img_path):img = image.load_img(img_path, target_size=(28, 28))img_array = image.img_to_array(img) / 255.0img_array = np.expand_dims(img_array, axis=0)return img_array# 预测图像
img_path = 'D:\Pycharm_workspace\LeNet实验_二分类\Demented\moderateDem24.jpg'  # 测试图像路径
img_array = preprocess_image(img_path)
prediction = model.predict(img_array)
predicted_class = 'Demented' if prediction[0][0] > 0.5 else 'NonDemented'print(f'The predicted class is: {predicted_class}')# 显示图像
img = image.load_img(img_path, target_size=(28, 28))
plt.imshow(img)
plt.title(f'Predicted: {predicted_class}')
plt.show()

4. 预测结果

Demented结果

 NonDemented结果没有。。。。。。

竟然全都没有。。。。因为预测的全部都是Demented

疯狂找原因中

猜测是像素太低使得训练的模型准确率太低

于是重新训练

5 改进训练模型

进行重新训练

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(1, activation='sigmoid')])model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_二分类\image',target_size=(176, 208),batch_size=32,class_mode='binary',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 这里还有图形画loss与准确率但是我忘记保存了,就用控制台的输出

 可以看到loss值非常小而且准确率是100

6 优化后 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
import os# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['Demented', 'NonDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[int(prediction[0] > 0.5)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_二分类\image\NonDemented\nonDem1.jpg'  # 替换为你的图片路径
predict_image(img_path)

 7 优化后预测结果

 图片与预测结果对应上了(右侧是图片链接可以看到是Dem的类型)

 NonDem的也是对应上了

 就此训练完成

 

8 训练四分类模型 

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt# 定义LeNet模型
def create_lenet_model(input_shape):model = Sequential([Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),MaxPooling2D((2, 2), strides=2),Conv2D(16, (5, 5), activation='relu'),MaxPooling2D((2, 2), strides=2),Flatten(),Dense(120, activation='relu'),Dense(84, activation='relu'),Dense(4, activation='softmax')])model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])return model# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)# 训练数据生成器
train_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='training'
)# 验证数据生成器
validation_generator = train_datagen.flow_from_directory('D:\Pycharm_workspace\LeNet实验_四分类\image',target_size=(176, 208),batch_size=32,class_mode='categorical',subset='validation'
)# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)# 保存模型
model.save('dementia_classification_model.h5')# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()

 

loss值与准确率的变化图

可以看到才第四轮准确率就已经很高了 

9 预测结果代码

import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt# 加载模型
model = load_model('dementia_classification_model.h5')# 定义类别标签
class_labels = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']# 预测函数
def predict_image(img_path):img = image.load_img(img_path, target_size=(176, 208))img_array = image.img_to_array(img)img_array = np.expand_dims(img_array, axis=0)img_array /= 255.0prediction = model.predict(img_array)predicted_class = class_labels[np.argmax(prediction)]# 显示图像和预测结果plt.imshow(image.load_img(img_path))plt.title(f'Predicted: {predicted_class}')plt.axis('off')plt.show()# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_四分类\image\VeryMildDemented\verymildDem0.jpg'  # 你的图片路径
predict_image(img_path)

 

10 四分类结果识别

1 MildDem成功识别(右侧有图片名称)

2 ModerateDem 成功识别

3 NonDem成功识别

 4 VeryMildDem成功识别

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【Python】TensorFlow介绍与实战
  • 55 、mysql的存储引擎、备份恢复以及日志备份、恢复
  • 【Vue3】响应式数据
  • RocketMQ~架构与工作流程了解
  • electron项目中实现视频下载保存到本地
  • 【深度学习】VGG-16原理及代码实现
  • 【深度学习】OCR中的Shrink操作详解
  • 【分布式事务】怎么解决分布式场景下数据一致性问题
  • Springboot 3.x - Reactive programming (2)
  • 钡铼Profinet、EtherCAT、Modbus、MQTT、Ethernet/IP、OPC UA分布式IO系统BL20X系列耦合器
  • GOLLIE : ANNOTATION GUIDELINES IMPROVE ZERO-SHOT INFORMATION-EXTRACTION
  • vue基于Cookies实现记住密码自动登录功能
  • Spring Boot外部配置加载顺序
  • Github报错:Kex_exchange_identification: Connection closed by remote host
  • Linux云计算 |【第一阶段】ENGINEER-DAY3
  • [数据结构]链表的实现在PHP中
  • 【翻译】Mashape是如何管理15000个API和微服务的(三)
  • DataBase in Android
  • ECS应用管理最佳实践
  • ES学习笔记(12)--Symbol
  • Java基本数据类型之Number
  • Python学习笔记 字符串拼接
  • React 快速上手 - 06 容器组件、展示组件、操作组件
  • RxJS 实现摩斯密码(Morse) 【内附脑图】
  • WebSocket使用
  • 初识 beanstalkd
  • 从零开始的无人驾驶 1
  • 分享一个自己写的基于canvas的原生js图片爆炸插件
  • 利用jquery编写加法运算验证码
  • 买一台 iPhone X,还是创建一家未来的独角兽?
  • 前端面试之CSS3新特性
  • 前端之React实战:创建跨平台的项目架构
  • 让你的分享飞起来——极光推出社会化分享组件
  • 小程序上传图片到七牛云(支持多张上传,预览,删除)
  • 学习笔记TF060:图像语音结合,看图说话
  • 一个JAVA程序员成长之路分享
  • 用Canvas画一棵二叉树
  • 用简单代码看卷积组块发展
  • 【运维趟坑回忆录】vpc迁移 - 吃螃蟹之路
  • LIGO、Virgo第三轮探测告捷,同时探测到一对黑洞合并产生的引力波事件 ...
  • 不要一棍子打翻所有黑盒模型,其实可以让它们发挥作用 ...
  • 如何用纯 CSS 创作一个菱形 loader 动画
  • ​​​​​​​sokit v1.3抓手机应用socket数据包: Socket是传输控制层协议,WebSocket是应用层协议。
  • ​sqlite3 --- SQLite 数据库 DB-API 2.0 接口模块​
  • ​力扣解法汇总946-验证栈序列
  • #window11设置系统变量#
  • #Z2294. 打印树的直径
  • $分析了六十多年间100万字的政府工作报告,我看到了这样的变迁
  • (145)光线追踪距离场柔和阴影
  • (2)MFC+openGL单文档框架glFrame
  • (C#)获取字符编码的类
  • (html转换)StringEscapeUtils类的转义与反转义方法
  • (Java岗)秋招打卡!一本学历拿下美团、阿里、快手、米哈游offer
  • (leetcode学习)236. 二叉树的最近公共祖先
  • (二)原生js案例之数码时钟计时