当前位置: 首页 > news >正文

模型优化之剪枝

文章目录

  • 什么是神经网络剪枝
  • 剪枝的好处
  • 不同粒度的剪枝
  • 剪枝的分类
    • 非结构化剪枝
    • 结构化剪枝
  • 哪些层的参数更容易被剪掉
  • 剪枝效果

什么是神经网络剪枝

神经网络剪枝

  • 在训练期间删除连接
  • 密集张量将变得稀疏(用零填充)
  • 可以通过结构化块( n m nm nm)或( 11 11 11)删除连接

在这里插入图片描述

剪枝的好处

  • 减少过拟合
  • 稀疏性优势
  • 文件中有大量的0,如果有适当的稀疏张量表示方法,模型二进制文件尺寸减小。
  • 模型更小,可以减少内存带宽消耗量。
  • 对于特定模式的稀疏模型,可以开发优化算子,实现加速推理。

不同粒度的剪枝

在这里插入图片描述
什么时候做剪枝?

one-shot pruning : 一次性修剪,包括三个步骤训练模型、剪枝、再训练
剪枝:通常根据某种标准(如权重的大小、梯度的大小等)一次性去除大量权重。
再训练:剪枝后,模型通常需要进行一定数量的额外训练(称为fine-tuning或再训练)来恢复剪枝过程中可能损失的性能。

iterative pruning: 迭代式训练,特点如下:
初始训练:首先,对未剪枝的完整模型进行训练,直到达到满意的性能水平。
剪枝:然后,根据某种剪枝策略(例如基于权重的大小或敏感度)剪除模型的部分组件(如权重、神经元或通道)。
再训练:剪枝后,重新训练模型以恢复因剪枝而丢失的性能。
迭代:重复剪枝和再训练的过程,直到达到所需的剪枝率或性能标准。

automated gradual pruning: 自动化渐进剪枝,特点如下:
剪枝策略:采用一种预定义的剪枝策略,例如基于权重阈值、敏感度分析等,该策略在整个剪枝过程中保持一致。
渐进剪枝:在整个训练过程中逐渐增加剪枝率,通常从较低的剪枝率开始,逐步增加到目标剪枝率。
无需再训练:在整个剪枝过程中,模型持续被训练,而不是在剪枝后重新训练。
自动化:整个过程高度自动化,可以减少人为干预的需求

在这里插入图片描述

剪枝的分类

结构化剪枝(Structured Pruning)和非结构化剪枝(Unstructured Pruning)是两种常见的神经网络剪枝方法,它们的主要区别在于剪枝后网络结构的变化以及剪枝操作的粒度。

非结构化剪枝

不改变网络结构或者参数数量,把连接上的参数置0即为剪枝。
基于某种度量(如权重的绝对值大小)对所有权重进行排序,然后根据预先设定的剪枝比例(例如去除50%的最小权重)来决定哪些权重被设置为零。这种剪枝方法不会考虑权重在模型中的位置或结构,只关注权重本身的价值。示例代码:

# 导入剪枝函数
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude# 计算两轮之后完成剪枝时对应的迭代次数end_step
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs# 定义剪枝模型参数,开始模型从50%稀疏度(权重为0的参数数量百分比),到80%稀疏度
pruning_params = {'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,final_sparsity=0.80,begin_step=0,end_step=end_step)
}model_for_pruning = prune_low_magnitude(model, **pruning_params)# 当使用函数`prune_low_magnitude`包装了一下模型后,需要重新编译一下
model_for_pruning.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model_for_pruning.summary()logdir = "./logs/mnist_pruning"callbacks = [tfmot.sparsity.keras.UpdatePruningStep(),tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]model_for_pruning.fit(train_images, train_labels,batch_size=batch_size, epochs=epochs, validation_split=validation_split,callbacks=callbacks)
# --------------------------------------------------
# 评估模型,对比剪枝前后模型的准确率变化
# 经过剪枝,这里有一个小的准确率下降,和没有进行剪枝相比的话
# --------------------------------------------------_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

结构化剪枝

结构化剪枝改变了网络结构,即网络层输出元素个数,比如卷积核的减少会影响特征图数量。
在下面的例子中是基于选择的模型层做剪枝,所以需要指出哪些层去做结构化剪枝。比如剪枝第二个卷积层和第一个全连接层,剪枝策略为pruning_params_2_by_4,表示该层剪枝比例为2 / 4,即该层保留一半(2/4)的权重,而将另一半设为零。
注意:第一个卷积层不能被结构化剪枝。要是结构化剪枝的话,应该至少大于一个input channels(本例所用图片为单通道灰度图),所以我们对第一个卷积层使用随机剪枝。

model = keras.Sequential([prune_low_magnitude(keras.layers.Conv2D(32, 5, padding='same', activation='relu',input_shape=(28, 28, 1),name="pruning_sparsity_0_5"),**pruning_params_sparsity_0_5),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),prune_low_magnitude(keras.layers.Conv2D(64, 5, padding='same',name="structural_pruning"),**pruning_params_2_by_4),keras.layers.BatchNormalization(),keras.layers.ReLU(),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),keras.layers.Flatten(),prune_low_magnitude(keras.layers.Dense(1024, activation='relu',name="structural_pruning_dense"),**pruning_params_2_by_4),keras.layers.Dropout(0.4),keras.layers.Dense(10)
])

哪些层的参数更容易被剪掉

因为卷积层(conv)中的参数相比全连接层(fc)来说参数量少,所以卷积层参数的压缩比没有全连接层参数的压缩比大。换句话说,就是卷积层参数更加敏感,剪掉对准确率影响相对更大。越靠后的卷积层或卷积层之后的那些全连接层往往参数越容易被剪掉。

剪枝效果

  • 一般50%-70%左右的稀疏性,准确率降低幅度并不大
  • 剪枝是独立于量化技巧,通常与量化配合效果不错
  • 可以通过微调尝试不同的参数组合

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • libevent之android与鸿蒙编译过程
  • H3C M-LAG与双活网关接口结合应用场景实验
  • 数据结构-链表-第二天
  • elasticsearch的高亮查询三种模式查询及可能存在的问题
  • 数据结构----双向链表
  • linux笔记1
  • 删除 Docker 容器的日志文件
  • 线程通信【详解】
  • 使用DOM破坏启动xss
  • 机器学习-识别手写数字
  • 网络编程day8
  • 系统编程 网络 基于tcp协议
  • JavaScript_10_练习:轮播图
  • 深度学习--RNN以及RNN的延伸
  • 「数组」数组双指针算法合集:二路合并|逆向合并|快慢去重|对撞指针 / LeetCode 88|26|11(C++)
  • 【140天】尚学堂高淇Java300集视频精华笔记(86-87)
  • 【347天】每日项目总结系列085(2018.01.18)
  • 【挥舞JS】JS实现继承,封装一个extends方法
  • Android开发 - 掌握ConstraintLayout(四)创建基本约束
  • Apache Pulsar 2.1 重磅发布
  • If…else
  • JDK 6和JDK 7中的substring()方法
  • node入门
  • oschina
  • 阿里云前端周刊 - 第 26 期
  • 百度小程序遇到的问题
  • 力扣(LeetCode)21
  • 前端技术周刊 2019-01-14:客户端存储
  • 如何设计一个微型分布式架构?
  • 适配mpvue平台的的微信小程序日历组件mpvue-calendar
  • 移动互联网+智能运营体系搭建=你家有金矿啊!
  • 译米田引理
  • 主流的CSS水平和垂直居中技术大全
  • Nginx惊现漏洞 百万网站面临“拖库”风险
  • ​​​​​​​​​​​​​​Γ函数
  • ​​​​​​​Installing ROS on the Raspberry Pi
  • ​【经验分享】微机原理、指令判断、判断指令是否正确判断指令是否正确​
  • # windows 安装 mysql 显示 no packages found 解决方法
  • # 手柄编程_北通阿修罗3动手评:一款兼具功能、操控性的电竞手柄
  • #{} 和 ${}区别
  • #android不同版本废弃api,新api。
  • #define
  • #php的pecl工具#
  • #我与Java虚拟机的故事#连载09:面试大厂逃不过的JVM
  • #我与Java虚拟机的故事#连载10: 如何在阿里、腾讯、百度、及字节跳动等公司面试中脱颖而出...
  • (4.10~4.16)
  • (Arcgis)Python编程批量将HDF5文件转换为TIFF格式并应用地理转换和投影信息
  • (arch)linux 转换文件编码格式
  • (pojstep1.1.2)2654(直叙式模拟)
  • (板子)A* astar算法,AcWing第k短路+八数码 带注释
  • (附源码)springboot美食分享系统 毕业设计 612231
  • (附源码)ssm考生评分系统 毕业设计 071114
  • (一) 初入MySQL 【认识和部署】
  • (转)iOS字体
  • (转)关于pipe()的详细解析