当前位置: 首页 > news >正文

图像分类实战:深度学习在CIFAR-10数据集上的应用

1.前言

        图像分类是计算机视觉领域的一个核心任务,算法能够自动识别图像中的物体或场景,并将其归类到预定义的类别中。近年来,深度学习技术的发展极大地推动了图像分类领域的进步。CIFAR-10数据集作为计算机视觉领域的一个经典小型数据集,为研究者提供了一个理想的实验平台,用于验证和比较不同的图像分类算法。本文将介绍CIFAR-10数据集的基本情况和加载方法,并展示如何构建与训练一个卷积神经网络(CNN)模型来进行图像分类,最后对模型的性能进行评估与可视化。

2.数据集介绍与加载

        CIFAR-10数据集由加拿大高等研究院(Canadian Institute for Advanced Research, CIFAR)发布,是计算机视觉领域广泛使用的基准数据集之一。它包含了10个类别(飞机、汽车、鸟类、猫、鹿、狗、青蛙、船、卡车、马)的彩色图像,每类有6,000张图像,共计60,000张。所有图像尺寸统一为32x32像素,且已进行标准化处理,其色彩模式为RGB。数据集被划分为50,000张训练图像和10,000张测试图像,保证了训练集与测试集的均衡分布。

        数据加载

        使用Python的tensorflow.keras.datasets模块加载CIFAR-10数据集,同时进行必要的预处理,如归一化和标签转换。

import tensorflow as tf# 加载CIFAR-10数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()# 数据归一化
x_train, x_test = x_train / 255.0, x_test / 255.0# 将标签转换为one-hot编码
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)

3.构建与训练CNN模型

        ResNet(Residual Neural Network)是一种深度残差学习网络,通过引入残差块解决了深度神经网络训练过程中的梯度消失和爆炸问题,从而能够构建和训练极深的模型,显著提升模型的性能和泛化能力。

        关于CNN模型的更多介绍,请看这篇文章:

卷积神经网络(CNN):图像识别的强大工具-CSDN博客文章浏览阅读795次,点赞9次,收藏18次。卷积神经网络是一种强大的图像识别工具,它能够自动学习图像的特征,并在各种图像识别任务中取得出色的效果。通过使用深度学习框架和大量的训练数据,我们可以构建出高效准确的卷积神经网络模型,实现对图像的分类、识别等任务。希望这篇文章能够帮助你更好地理解卷积神经网络在图像识别中的应用。如果你有任何问题或需要进一步的帮助,请随时提问。https://blog.csdn.net/meijinbo/article/details/137015665

3.1.构建模型

        使用Keras构建一个适用于CIFAR-10数据集的小型ResNet模型。

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Add, MaxPooling2D, GlobalAveragePooling2D, Densedef residual_block(input_tensor, filters, strides=1, use_projection=False):shortcut = input_tensorif use_projection:shortcut = Conv2D(filters, kernel_size=1, strides=strides, padding='valid')(shortcut)shortcut = BatchNormalization()(shortcut)x = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(input_tensor)x = BatchNormalization()(x)x = Activation('relu')(x)x = Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)x = BatchNormalization()(x)if strides != 1 or input_tensor.shape[-1] != filters:shortcut = Conv2D(filters, kernel_size=1, strides=strides, padding='valid')(shortcut)shortcut = BatchNormalization()(shortcut)x = Add()([shortcut, x])x = Activation('relu')(x)return xdef build_resnet():model = Sequential()model.add(Conv2D(16, kernel_size=3, padding='same', input_shape=(32, 32, 3)))model.add(BatchNormalization())model.add(Activation('relu'))for _ in range(2):model.add(residual_block(model.output, 16))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(residual_block(model.output, 32, strides=2, use_projection=True))for _ in range(2):model.add(residual_block(model.output, 32))model.add(GlobalAveragePooling2D())model.add(Dense(10, activation='softmax'))return modelresnet_model = build_resnet()
resnet_model.summary()

3.2.模型训练

        配置模型训练参数,启动训练过程,并监控训练进度。

resnet_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])history = resnet_model.fit(x_train, y_train,batch_size=128,epochs=100,validation_data=(x_test, y_test),verbose=1)

4.模型性能评估与可视化

4.1.性能评估

        评估模型在测试集上的最终性能指标。

test_loss, test_acc = resnet_model.evaluate(x_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc:.4f}')

 4.2.可视化

        绘制训练过程中损失和准确率曲线,以直观了解模型收敛情况与过拟合风险。

import matplotlib.pyplot as pltdef plot_history(history):plt.figure(figsize=(12, 6))plt.subplot(1, 2, 1)plt.plot(history.history['accuracy'], label='Training Accuracy')plt.plot(history.history['val_accuracy'], label='Validation Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.subplot(1, 2, 2)plt.plot(history.history['loss'], label='Training Loss')plt.plot(history.history['val_loss'], label='Validation Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.show()plot_history(history)  # 显示训练过程中的准确率与损失曲线

        以下是基于PyTorch的实现:

import torch.nn as nn  
import torch.nn.functional as F  class SimpleCNN(nn.Module):  def __init__(self):  super(SimpleCNN, self).__init__()  self.conv1 = nn.Conv2d(3, 6, 5)  self.pool = nn.MaxPool2d(2, 2)  self.conv2 = nn.Conv2d(6, 16, 5)  self.fc1 = nn.Linear(16 * 5 * 5, 120)  self.fc2 = nn.Linear(120, 84)  self.fc3 = nn.Linear(84, 10)  def forward(self, x):  x = self.pool(F.relu(self.conv1(x)))  x = self.pool(F.relu(self.conv2(x)))  x = x.view(-1, 16 * 5 * 5)  x = F.relu(self.fc1(x))  x = F.relu(self.fc2(x))  x = self.fc3(x)  return x  # 实例化模型、定义损失函数和优化器  
model = SimpleCNN()  
criterion = nn.CrossEntropyLoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # 训练模型  
for epoch in range(2):  # 假设我们训练两个epoch  running_loss = 0.0  for i, data in enumerate(trainloader, 0):  inputs, labels = data  optimizer.zero_grad()  outputs = model(inputs)  loss = criterion(outputs, labels)  loss.backward()  optimizer.step()  running_loss += loss.item()  if i % 2000 == 1999:  # 每2

 5.总结

        通过以上步骤,我们已经完成了在CIFAR-10数据集上使用深度学习进行图像分类的全过程。从数据集的介绍与加载,到构建并训练ResNet模型,再到模型性能的评估与可视化,这一系列操作展示了如何将理论知识应用于实际问题,揭示了深度学习在图像分类任务中的强大能力。实践中,可根据具体需求调整模型结构、优化策略等参数以进一步提升模型性能。

相关文章:

  • Java 中文官方教程 2022 版翻译完成
  • LQR的横向控制与算法仿真实现
  • BaseDao封装增删改查
  • mybatisplus如何拼接动态sql
  • 13 React useEffect 详解
  • uniapp先显示提示消息再返回上一页
  • 数据结构刷题篇 之 【力扣二叉树基础OJ】详细讲解(含每道题链接及递归图解)
  • Python 进阶教程
  • 算法部署总结
  • math模块篇(七)
  • 【笔试】美团2023年秋招第1场笔试(后端数开软件方向)
  • Java基础语法(二)
  • 骗子查询系统源码
  • 在vue中使用echarts饼图示例
  • C++——vector类及其模拟实现
  • 「前端早读君006」移动开发必备:那些玩转H5的小技巧
  • 【腾讯Bugly干货分享】从0到1打造直播 App
  • CentOS 7 修改主机名
  • download使用浅析
  • JavaScript创建对象的四种方式
  • Javascript基础之Array数组API
  • linux安装openssl、swoole等扩展的具体步骤
  • Linux快速配置 VIM 实现语法高亮 补全 缩进等功能
  • macOS 中 shell 创建文件夹及文件并 VS Code 打开
  • Objective-C 中关联引用的概念
  • Vue实战(四)登录/注册页的实现
  • WordPress 获取当前文章下的所有附件/获取指定ID文章的附件(图片、文件、视频)...
  • 批量截取pdf文件
  • 前端存储 - localStorage
  • 前端面试之CSS3新特性
  • 微服务入门【系列视频课程】
  • PostgreSQL之连接数修改
  • 关于Android全面屏虚拟导航栏的适配总结
  • # 数据结构
  • (编译到47%失败)to be deleted
  • (附源码)计算机毕业设计SSM智慧停车系统
  • (六)激光线扫描-三维重建
  • (深度全面解析)ChatGPT的重大更新给创业者带来了哪些红利机会
  • (算法)Travel Information Center
  • (已解决)vue+element-ui实现个人中心,仿照原神
  • (原创)攻击方式学习之(4) - 拒绝服务(DOS/DDOS/DRDOS)
  • (转)JAVA中的堆栈
  • (转)创业的注意事项
  • ****** 二十三 ******、软设笔记【数据库】-数据操作-常用关系操作、关系运算
  • ***测试-HTTP方法
  • . Flume面试题
  • .naturalWidth 和naturalHeight属性,
  • .NET 5.0正式发布,有什么功能特性(翻译)
  • .NET/C# 避免调试器不小心提前计算本应延迟计算的值
  • .NET/C# 在代码中测量代码执行耗时的建议(比较系统性能计数器和系统时间)
  • .net打印*三角形
  • .Net高阶异常处理第二篇~~ dump进阶之MiniDumpWriter
  • .NET使用HttpClient以multipart/form-data形式post上传文件及其相关参数
  • .Net小白的大学四年,内含面经
  • /dev/sda2 is mounted; will not make a filesystem here!