动画展示梯度下降(二维)
动画展示梯度下降(二维)
flyfish
-
类初始化 :
GradientDescentAnimation
类用于初始化梯度下降算法,包括目标函数、学习率、训练轮数等参数。 -
执行梯度下降 :
perform_gradient_descent
方法执行梯度下降算法,并记录每次迭代的局部最小值和梯度。 -
创建动画 :
create_animation
方法负责创建和展示梯度下降的动画
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation# 设置中文字体支持和坐标轴负号显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = Falseclass GradientDescentAnimation:def __init__(self, expr, x_values, learning_rate, training_epochs):"""初始化梯度下降动画类:param expr: 表达式,使用 sympy 表示的目标函数:param x_values: numpy 数组,x 的取值范围:param learning_rate: float,学习率:param training_epochs: int,训练轮数"""self.expr = expr # 目标函数的符号表达式self.x_values = x_values # x 的取值范围self.learning_rate = learning_rate # 学习率,用于控制梯度下降的步长self.training_epochs = training_epochs # 训练轮数,即迭代次数self.x = sp.symbols("x") # 定义符号变量 x# 将符号表达式转换为可用于数值计算的函数self.func = sp.lambdify(self.x, self.expr, "numpy")# 计算目标函数的导数self.deriv = sp.diff(self.expr)# 将导数表达式转换为可用于数值计算的函数self.deriv_func = sp.lambdify(self.x, self.deriv, "numpy")# 创建一个数组用于存储每个 epoch 的局部最小值和梯度self.model_params = np.zeros((training_epochs, 2)) def perform_gradient_descent(self):"""执行梯度下降算法,记录每次更新的局部最小值和梯度"""# 随机选择一个初始局部最小值self.local_min = np.random.choice(self.x_values, 1)print(f"初始局部最小值: {self.local_min}")for i in range(self.training_epochs):# 计算当前局部最小值处的梯度grad = self.deriv_func(self.local_min)# 根据梯度下降更新局部最小值self.local_min = self.local_min - (grad * self.learning_rate)# 记录当前的局部最小值self.model_params[i, 0] = self.local_min[0]# 记录当前的梯度self.model_params[i, 1] = grad[0] print(f"{self.training_epochs}轮训练后的局部最小值: {self.local_min}")def create_animation(self):"""创建梯度下降过程的动画"""grad_fig, ax = plt.subplots(figsize=(12, 6), dpi=100)# 绘制目标函数和其导数的图像ax.plot(self.x_values, self.func(self.x_values), label=f"${sp.latex(self.expr)}$")ax.plot(self.x_values, self.deriv_func(self.x_values), label=f"dy/dx ${sp.latex(self.deriv)}$")# 设置初始图表标题、网格、坐标轴标签等ax.set_title(f"初始局部最小值: {self.local_min[0]}")plt.axhline(0, color='white', linewidth=0.5)plt.axvline(0, color='white', linewidth=0.5)plt.grid(color="gray", linestyle="--", linewidth=0.5)plt.xlabel("x")plt.ylabel("f(x)")plt.legend()def tangent_line(x, x1, y1):"""计算切线:param x: numpy 数组,x 的取值范围:param x1: float,切点的x坐标:param y1: float,切点的y坐标:return: numpy 数组,切线在每个 x 值的对应 y 值"""return self.deriv_func(x1) * (x - x1) + y1# 初始化图像元素title = ax.text(0.5, 1.05, '', transform=ax.transAxes, ha="center", fontweight="bold")initial_local_min = self.model_params[0, 0]# 用于标记局部最小值的散点local_min_scat = ax.scatter(initial_local_min, self.func(initial_local_min), color="orange")# 计算初始切线的范围initial_tangent_range = np.linspace(initial_local_min - 0.5, initial_local_min + 0.5, 10)# 绘制初始切线tangent_plot = ax.plot(initial_tangent_range, tangent_line(x=initial_tangent_range, x1=initial_local_min, y1=self.func(initial_local_min)), linestyle="--", color="orange", linewidth=2)[0]# 添加注释以显示当前梯度grad_annotation = ax.annotate('Gradient={0:2f}'.format(self.deriv_func(initial_local_min)),xy=(initial_local_min, self.func(initial_local_min)), xytext=(initial_local_min, self.func(initial_local_min) + 1),arrowprops={'arrowstyle': "-", 'facecolor': 'orange'},textcoords='data', color='orange', rotation=20, fontweight="bold")def drawframe(epoch):"""绘制每一帧:param epoch: int,当前的训练轮数"""title.set_text(f'Epoch={epoch:4d}, 当前局部最小值: {self.model_params[epoch, 0]:.2f}')# 更新局部最小值的位置x1 = self.model_params[epoch, 0]y1 = self.func(self.model_params[epoch, 0])local_min_scat.set_offsets((x1, y1))# 更新切线tangent_range = np.linspace(x1 - 0.5, x1 + 0.5, 10)tangent_values = tangent_line(x=tangent_range, x1=x1, y1=y1)tangent_plot.set_xdata(tangent_range)tangent_plot.set_ydata(tangent_values)# 更新注释grad_annotation.set_position((x1, y1 + 1))grad_annotation.xy = (x1, y1)grad_annotation.set_text('Gradient={0:2f}'.format(self.model_params[epoch, 1]))return local_min_scat,# 创建动画anim = animation.FuncAnimation(grad_fig, drawframe, frames=self.training_epochs, repeat=False, interval=500, blit=True)return anim# 使用示例
x = sp.symbols("x") # 在此处定义符号变量 x
expr = 3 * x ** 2 - 3 * x + 4x_values = np.linspace(-2, 2, 20)
grad_anim = GradientDescentAnimation(expr=expr, x_values=x_values, learning_rate=0.001, training_epochs=100)
grad_anim.perform_gradient_descent()
anim = grad_anim.create_animation()
anim.save('gradient.gif', writer=animation.PillowWriter(fps=30)) # 保存动画为 GIF 文件
代码解释
self.x = sp.symbols("x") # 定义符号变量 x
# 将符号表达式转换为可用于数值计算的函数
self.func = sp.lambdify(self.x, self.expr, "numpy")
# 计算目标函数的导数
self.deriv = sp.diff(self.expr)
# 将导数表达式转换为可用于数值计算的函数
self.deriv_func = sp.lambdify(self.x, self.deriv, "numpy")
-
self.x = sp.symbols("x")
:
sp.symbols("x")
:这是在 SymPy 中创建一个符号变量x
。符号变量用于表示数学表达式中的变量,可以被用于符号计算。
self.x
:将符号变量x
存储为类的一个属性,使其可以在其他方法中使用。 -
self.func = sp.lambdify(self.x, self.expr, "numpy")
:
sp.lambdify(self.x, self.expr, "numpy")
:将符号表达式self.expr
转换为一个可以在 Python 中进行数值计算的函数。这是通过 SymPy 的lambdify
函数实现的。
self.expr
是传递给GradientDescentAnimation
类的目标函数符号表达式。
"numpy"
参数指定使用 NumPy 库进行数值计算,这样生成的函数可以对 NumPy 数组进行快速运算。
self.func
:将生成的数值计算函数存储为类的一个属性。这个函数可以接受数值或数组作为输入,并返回计算结果。 -
self.deriv = sp.diff(self.expr)
:
sp.diff(self.expr)
:对符号表达式self.expr
进行求导运算,返回该表达式关于变量x
的导数。
self.deriv
:将求得的导数表达式存储为类的一个属性。 -
self.deriv_func = sp.lambdify(self.x, self.deriv, "numpy")
:
sp.lambdify(self.x, self.deriv, "numpy")
:将导数表达式self.deriv
转换为一个可以进行数值计算的函数。
self.deriv_func
:将生成的导数计算函数存储为类的一个属性。这个函数可以接受数值或数组作为输入,并返回导数的计算结果。
SymPy
是一个用于符号数学计算的 Python 库,支持多种数学操作和功能。以下是一些常用的 SymPy
用法示例:
1. 定义符号变量
import sympy as sp# 创建符号变量
x = sp.symbols('x')
y, z = sp.symbols('y z')# 使用多个变量
a, b, c = sp.symbols('a b c')
2. 表达式的创建和操作
# 定义一个符号表达式
expr = x**2 + 2*x + 1# 扩展表达式
expanded_expr = sp.expand((x + 1)**2)
print(expanded_expr) # 输出: x**2 + 2*x + 1# 简化表达式
simplified_expr = sp.simplify(x**2 + 2*x + 1)
print(simplified_expr) # 输出: x**2 + 2*x + 1
3. 求导
# 对表达式求导
derivative = sp.diff(expr, x)
print(derivative) # 输出: 2*x + 2
4. 求积分
# 不定积分
indefinite_integral = sp.integrate(expr, x)
print(indefinite_integral) # 输出: x**3/3 + x**2 + x# 定积分
definite_integral = sp.integrate(expr, (x, 0, 1))
print(definite_integral) # 输出: 7/3
5. 解方程
# 解一元方程
solution = sp.solve(x**2 + 2*x + 1, x)
print(solution) # 输出: [-1]# 解联立方程
eq1 = sp.Eq(2*x + y, 1)
eq2 = sp.Eq(x - y, 3)
solution_system = sp.solve((eq1, eq2), (x, y))
print(solution_system) # 输出: {x: 2, y: -1}
6. 极限
# 求极限
limit_result = sp.limit(sp.sin(x)/x, x, 0)
print(limit_result) # 输出: 1
7. 级数展开(泰勒展开)
# 泰勒展开
taylor_series = sp.series(sp.sin(x), x, 0, 5)
print(taylor_series) # 输出: x - x**3/6 + x**5/120 + O(x**6)
8. 数值计算
使用 lambdify
将符号表达式转换为可用于数值计算的函数:
# 使用 lambdify 将符号表达式转换为数值计算函数
numerical_func = sp.lambdify(x, expr, 'numpy')# 在数值上进行计算
result = numerical_func(2) # 传入具体数值
print(result) # 输出: 9
9. 符号矩阵和线性代数
# 矩阵操作
matrix = sp.Matrix([[1, 2], [3, 4]])
determinant = matrix.det()
print(determinant) # 输出: -2# 逆矩阵
inverse_matrix = matrix.inv()
print(inverse_matrix)
# 输出:
# Matrix([
# [ -2, 1],
# [3/2, -1/2]])
10. Plotting (绘图)
虽然 SymPy
提供了一些简单的绘图功能,但通常会使用 Matplotlib
进行更复杂的绘图:
import sympy.plotting as syp# 绘制函数图像
syp.plot(expr, (x, -5, 5), title="Function Plot", ylabel="f(x)")