当前位置: 首页 > news >正文

动画展示梯度下降(二维)

动画展示梯度下降(二维)

flyfish

在这里插入图片描述

  1. 类初始化GradientDescentAnimation 类用于初始化梯度下降算法,包括目标函数、学习率、训练轮数等参数。

  2. 执行梯度下降perform_gradient_descent 方法执行梯度下降算法,并记录每次迭代的局部最小值和梯度。

  3. 创建动画create_animation 方法负责创建和展示梯度下降的动画

import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation# 设置中文字体支持和坐标轴负号显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = Falseclass GradientDescentAnimation:def __init__(self, expr, x_values, learning_rate, training_epochs):"""初始化梯度下降动画类:param expr: 表达式,使用 sympy 表示的目标函数:param x_values: numpy 数组,x 的取值范围:param learning_rate: float,学习率:param training_epochs: int,训练轮数"""self.expr = expr  # 目标函数的符号表达式self.x_values = x_values  # x 的取值范围self.learning_rate = learning_rate  # 学习率,用于控制梯度下降的步长self.training_epochs = training_epochs  # 训练轮数,即迭代次数self.x = sp.symbols("x")  # 定义符号变量 x# 将符号表达式转换为可用于数值计算的函数self.func = sp.lambdify(self.x, self.expr, "numpy")# 计算目标函数的导数self.deriv = sp.diff(self.expr)# 将导数表达式转换为可用于数值计算的函数self.deriv_func = sp.lambdify(self.x, self.deriv, "numpy")# 创建一个数组用于存储每个 epoch 的局部最小值和梯度self.model_params = np.zeros((training_epochs, 2)) def perform_gradient_descent(self):"""执行梯度下降算法,记录每次更新的局部最小值和梯度"""# 随机选择一个初始局部最小值self.local_min = np.random.choice(self.x_values, 1)print(f"初始局部最小值: {self.local_min}")for i in range(self.training_epochs):# 计算当前局部最小值处的梯度grad = self.deriv_func(self.local_min)# 根据梯度下降更新局部最小值self.local_min = self.local_min - (grad * self.learning_rate)# 记录当前的局部最小值self.model_params[i, 0] = self.local_min[0]# 记录当前的梯度self.model_params[i, 1] = grad[0] print(f"{self.training_epochs}轮训练后的局部最小值: {self.local_min}")def create_animation(self):"""创建梯度下降过程的动画"""grad_fig, ax = plt.subplots(figsize=(12, 6), dpi=100)# 绘制目标函数和其导数的图像ax.plot(self.x_values, self.func(self.x_values), label=f"${sp.latex(self.expr)}$")ax.plot(self.x_values, self.deriv_func(self.x_values), label=f"dy/dx ${sp.latex(self.deriv)}$")# 设置初始图表标题、网格、坐标轴标签等ax.set_title(f"初始局部最小值: {self.local_min[0]}")plt.axhline(0, color='white', linewidth=0.5)plt.axvline(0, color='white', linewidth=0.5)plt.grid(color="gray", linestyle="--", linewidth=0.5)plt.xlabel("x")plt.ylabel("f(x)")plt.legend()def tangent_line(x, x1, y1):"""计算切线:param x: numpy 数组,x 的取值范围:param x1: float,切点的x坐标:param y1: float,切点的y坐标:return: numpy 数组,切线在每个 x 值的对应 y 值"""return self.deriv_func(x1) * (x - x1) + y1# 初始化图像元素title = ax.text(0.5, 1.05, '', transform=ax.transAxes, ha="center", fontweight="bold")initial_local_min = self.model_params[0, 0]# 用于标记局部最小值的散点local_min_scat = ax.scatter(initial_local_min, self.func(initial_local_min), color="orange")# 计算初始切线的范围initial_tangent_range = np.linspace(initial_local_min - 0.5, initial_local_min + 0.5, 10)# 绘制初始切线tangent_plot = ax.plot(initial_tangent_range, tangent_line(x=initial_tangent_range, x1=initial_local_min, y1=self.func(initial_local_min)), linestyle="--", color="orange", linewidth=2)[0]# 添加注释以显示当前梯度grad_annotation = ax.annotate('Gradient={0:2f}'.format(self.deriv_func(initial_local_min)),xy=(initial_local_min, self.func(initial_local_min)), xytext=(initial_local_min, self.func(initial_local_min) + 1),arrowprops={'arrowstyle': "-", 'facecolor': 'orange'},textcoords='data', color='orange', rotation=20, fontweight="bold")def drawframe(epoch):"""绘制每一帧:param epoch: int,当前的训练轮数"""title.set_text(f'Epoch={epoch:4d}, 当前局部最小值: {self.model_params[epoch, 0]:.2f}')# 更新局部最小值的位置x1 = self.model_params[epoch, 0]y1 = self.func(self.model_params[epoch, 0])local_min_scat.set_offsets((x1, y1))# 更新切线tangent_range = np.linspace(x1 - 0.5, x1 + 0.5, 10)tangent_values = tangent_line(x=tangent_range, x1=x1, y1=y1)tangent_plot.set_xdata(tangent_range)tangent_plot.set_ydata(tangent_values)# 更新注释grad_annotation.set_position((x1, y1 + 1))grad_annotation.xy = (x1, y1)grad_annotation.set_text('Gradient={0:2f}'.format(self.model_params[epoch, 1]))return local_min_scat,# 创建动画anim = animation.FuncAnimation(grad_fig, drawframe, frames=self.training_epochs, repeat=False, interval=500, blit=True)return anim# 使用示例
x = sp.symbols("x")  # 在此处定义符号变量 x
expr = 3 * x ** 2 - 3 * x + 4x_values = np.linspace(-2, 2, 20)
grad_anim = GradientDescentAnimation(expr=expr, x_values=x_values, learning_rate=0.001, training_epochs=100)
grad_anim.perform_gradient_descent()
anim = grad_anim.create_animation()
anim.save('gradient.gif', writer=animation.PillowWriter(fps=30))  # 保存动画为 GIF 文件

代码解释

self.x = sp.symbols("x")  # 定义符号变量 x
# 将符号表达式转换为可用于数值计算的函数
self.func = sp.lambdify(self.x, self.expr, "numpy")
# 计算目标函数的导数
self.deriv = sp.diff(self.expr)
# 将导数表达式转换为可用于数值计算的函数
self.deriv_func = sp.lambdify(self.x, self.deriv, "numpy")
  1. self.x = sp.symbols("x") :
    sp.symbols("x"):这是在 SymPy 中创建一个符号变量 x。符号变量用于表示数学表达式中的变量,可以被用于符号计算。
    self.x:将符号变量 x 存储为类的一个属性,使其可以在其他方法中使用。

  2. self.func = sp.lambdify(self.x, self.expr, "numpy") :
    sp.lambdify(self.x, self.expr, "numpy"):将符号表达式 self.expr 转换为一个可以在 Python 中进行数值计算的函数。这是通过 SymPy 的 lambdify 函数实现的。
    self.expr 是传递给 GradientDescentAnimation 类的目标函数符号表达式。
    "numpy" 参数指定使用 NumPy 库进行数值计算,这样生成的函数可以对 NumPy 数组进行快速运算。
    self.func:将生成的数值计算函数存储为类的一个属性。这个函数可以接受数值或数组作为输入,并返回计算结果。

  3. self.deriv = sp.diff(self.expr) :
    sp.diff(self.expr):对符号表达式 self.expr 进行求导运算,返回该表达式关于变量 x 的导数。
    self.deriv:将求得的导数表达式存储为类的一个属性。

  4. self.deriv_func = sp.lambdify(self.x, self.deriv, "numpy") :
    sp.lambdify(self.x, self.deriv, "numpy"):将导数表达式 self.deriv 转换为一个可以进行数值计算的函数。

self.deriv_func:将生成的导数计算函数存储为类的一个属性。这个函数可以接受数值或数组作为输入,并返回导数的计算结果。

SymPy 是一个用于符号数学计算的 Python 库,支持多种数学操作和功能。以下是一些常用的 SymPy 用法示例:

1. 定义符号变量

import sympy as sp# 创建符号变量
x = sp.symbols('x')
y, z = sp.symbols('y z')# 使用多个变量
a, b, c = sp.symbols('a b c')

2. 表达式的创建和操作

# 定义一个符号表达式
expr = x**2 + 2*x + 1# 扩展表达式
expanded_expr = sp.expand((x + 1)**2)
print(expanded_expr)  # 输出: x**2 + 2*x + 1# 简化表达式
simplified_expr = sp.simplify(x**2 + 2*x + 1)
print(simplified_expr)  # 输出: x**2 + 2*x + 1

3. 求导

# 对表达式求导
derivative = sp.diff(expr, x)
print(derivative)  # 输出: 2*x + 2

4. 求积分

# 不定积分
indefinite_integral = sp.integrate(expr, x)
print(indefinite_integral)  # 输出: x**3/3 + x**2 + x# 定积分
definite_integral = sp.integrate(expr, (x, 0, 1))
print(definite_integral)  # 输出: 7/3

5. 解方程

# 解一元方程
solution = sp.solve(x**2 + 2*x + 1, x)
print(solution)  # 输出: [-1]# 解联立方程
eq1 = sp.Eq(2*x + y, 1)
eq2 = sp.Eq(x - y, 3)
solution_system = sp.solve((eq1, eq2), (x, y))
print(solution_system)  # 输出: {x: 2, y: -1}

6. 极限

# 求极限
limit_result = sp.limit(sp.sin(x)/x, x, 0)
print(limit_result)  # 输出: 1

7. 级数展开(泰勒展开)

# 泰勒展开
taylor_series = sp.series(sp.sin(x), x, 0, 5)
print(taylor_series)  # 输出: x - x**3/6 + x**5/120 + O(x**6)

8. 数值计算

使用 lambdify 将符号表达式转换为可用于数值计算的函数:

# 使用 lambdify 将符号表达式转换为数值计算函数
numerical_func = sp.lambdify(x, expr, 'numpy')# 在数值上进行计算
result = numerical_func(2)  # 传入具体数值
print(result)  # 输出: 9

9. 符号矩阵和线性代数

# 矩阵操作
matrix = sp.Matrix([[1, 2], [3, 4]])
determinant = matrix.det()
print(determinant)  # 输出: -2# 逆矩阵
inverse_matrix = matrix.inv()
print(inverse_matrix)
# 输出:
# Matrix([
# [ -2,  1],
# [3/2, -1/2]])

10. Plotting (绘图)

虽然 SymPy 提供了一些简单的绘图功能,但通常会使用 Matplotlib 进行更复杂的绘图:

import sympy.plotting as syp# 绘制函数图像
syp.plot(expr, (x, -5, 5), title="Function Plot", ylabel="f(x)")

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • XSS的DOM破坏
  • Linux·权限与工具-yum与vim
  • 说一下Android中的IdleHandler
  • 每日一问:Kafka消息丢失与堆积问题分析与解决方案
  • MFC在OPENGL循环绘制中添加进度条控件后运行速度变慢
  • 设计模式 - 装饰器模式
  • 在IntelliJ IDEA中使用Git推送项目
  • [手机Linux PostmarketOS]五, docker安装和使用
  • Unity如何使用Spine动画导出的动画
  • webrtc学习笔记3
  • HTTP的认证方式
  • C# 使用泛型协变性
  • c语言----取反用什么符号
  • qt笔记之纯qml项目详解
  • ant design pro 如何去保存颜色
  • [译]前端离线指南(上)
  • CentOS7 安装JDK
  • echarts花样作死的坑
  • eclipse(luna)创建web工程
  • extjs4学习之配置
  • JavaScript 是如何工作的:WebRTC 和对等网络的机制!
  • js对象的深浅拷贝
  • Linux快速配置 VIM 实现语法高亮 补全 缩进等功能
  • magento2项目上线注意事项
  • MySQL用户中的%到底包不包括localhost?
  • open-falcon 开发笔记(一):从零开始搭建虚拟服务器和监测环境
  • STAR法则
  • Webpack 4 学习01(基础配置)
  • yii2中session跨域名的问题
  • 大数据与云计算学习:数据分析(二)
  • 关于extract.autodesk.io的一些说明
  • 回顾2016
  • 基于游标的分页接口实现
  • 看域名解析域名安全对SEO的影响
  • 浅谈JavaScript的面向对象和它的封装、继承、多态
  • 日剧·日综资源集合(建议收藏)
  • 如何在 Tornado 中实现 Middleware
  •  一套莫尔斯电报听写、翻译系统
  • 原生 js 实现移动端 Touch 滑动反弹
  • 【云吞铺子】性能抖动剖析(二)
  • 不要一棍子打翻所有黑盒模型,其实可以让它们发挥作用 ...
  • ​LeetCode解法汇总1276. 不浪费原料的汉堡制作方案
  • ​力扣解法汇总946-验证栈序列
  • # 职场生活之道:善于团结
  • #我与Java虚拟机的故事#连载07:我放弃了对JVM的进一步学习
  • $nextTick的使用场景介绍
  • (C语言)求出1,2,5三个数不同个数组合为100的组合个数
  • (WSI分类)WSI分类文献小综述 2024
  • (板子)A* astar算法,AcWing第k短路+八数码 带注释
  • (动手学习深度学习)第13章 计算机视觉---微调
  • (含react-draggable库以及相关BUG如何解决)固定在左上方某盒子内(如按钮)添加可拖动功能,使用react hook语法实现
  • (十七)devops持续集成开发——使用jenkins流水线pipeline方式发布一个微服务项目
  • *1 计算机基础和操作系统基础及几大协议
  • .MSSQLSERVER 导入导出 命令集--堪称经典,值得借鉴!
  • .Net 6.0--通用帮助类--FileHelper