当前位置: 首页 > news >正文

RNN股票预测(Pytorch版)

任务:基于zgpa_train.csv数据,建立RNN模型,预测股价
1.完成数据预处理,将序列数据转化为可用于RNN输入的数据
2.对新数据zgpa_test.csv进行预测,可视化结果
3.存储预测结果,并观察局部预测结果
备注:模型结构:单层RNN,输出有5个神经元,每次使用前8个数据预测第9个数据
参考视频:吹爆!3小时搞懂!【RNN循环神经网络+时间序列LSTM深度学习模型】学不会UP主下跪!
up主用的Keras,自己用Pytorch尝试了一下,代码如下:

import pandas as pd
import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
data = pd.read_csv('zgpa_train.csv')
# loc 通过行索引 “Index” 中的具体值来取行数据
# 取出开盘价
price = data.loc[:,'close']# 归一化
price_norm = price/max(price)
# 开盘价折线图
# fig1 = plt.figure(figsize=(10, 6))
# plt.plot(price)
# plt.title('close price')
# plt.xlabel('time')
# plt.ylabel('price')
# plt.show()# 提取数据 每次使用前8个数据来预测第九个数据
def extract_data(data, time_step):x = []y = []for i in range(len(data)- time_step):x.append([a for a in data[i:i+time_step]])y.append(data[i + time_step])x = np.array(x)x = x.reshape(x.shape[0], x.shape[1], 1)x = torch.tensor(x, dtype=torch.float32)y = torch.tensor(y, dtype=torch.float32)return x, y
time_step = 8
x, y = extract_data(price_norm,time_step)
# print(x)
# print(y)
class RNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers):super(RNN,self).__init__()self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first = True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):out, _ = self.rnn(x)# print(out)out = self.fc(out[:, -1, :])out = out.squeeze(1)return out
# 定义模型参数
input_size = 1 # 输入特征的维度
hidden_size = 64 # 隐藏层的维度
output_size = 1 # 输出特征的维度
num_layers = 1 # RNN的层数# 创建模型
model = RNN(input_size, hidden_size, output_size, num_layers)# 定义损失函数和优化器
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
epochs = 200
for epoch in range(epochs):optimizer.zero_grad()# outputs = model(x.unsqueeze(2))outputs = model(x)loss = criterion(outputs, y)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')
# 进行预测 数据很少这里就不先保存模型再预测了
model.eval()
with torch.no_grad():y_train_predict = model(x) * max(price)
y_train = [i * max(price) for i in y]
# print(y_train_predict)
y_train_predict = y_train_predict.cpu().numpy()
y_train = np.array(y_train)
fig2 = plt.figure(figsize=(10, 6))
plt.plot(y_train_predict, label='Predicted', color='blue')
plt.plot(y_train, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()# 测试集
data_test = pd.read_csv('zgpa_test.csv')
price_test = data_test.loc[:,'close']
price_test_norm = price_test/max(price)
x_test,y_test = extract_data(price_test_norm,time_step)
with torch.no_grad():y_test_predict = model(x_test) * max(price)
y_test = [i * max(price) for i in y_test]
# print(y_train_predict)
y_test_predict = y_test_predict.cpu().numpy()
y_test = np.array(y_test)
fig3 = plt.figure(figsize=(10, 6))
plt.plot(y_test_predict, label='Predicted', color='blue')
plt.plot(y_test, label='True', color='red', alpha=0.6)
plt.title('Predicted vs True Values (Test Set)')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()# 存储数据
result_y_test = np.array(y_test).reshape(-1, 1) # 若干行,1列
result_y_test_predict = y_test_predict.reshape(-1, 1)
print(result_y_test.shape, result_y_test_predict.shape)
result = np.concatenate((result_y_test, result_y_test_predict), axis=1)
print(result.shape)
result = pd.DataFrame(result, columns=['real_price_test', 'predict_price_test'])
result.to_csv('zgpa_predict_test.csv')

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【AI视频】复刻抖音爆款AI数字人作品初体验
  • TS - tsconfig.json 和 tsconfig.node.json 的关系,如何在TS 中使用 JS 不报错
  • 【Petri网导论学习笔记】Petri网导论入门学习(三)
  • spring模块(六)spring event事件(3)广播与异步问题
  • 【时时三省】tessy 单元测试 集成测试 专栏 文章阅读说明
  • 利用AI驱动智能BI数据可视化-深度评测Amazon Quicksight(三)
  • UE5安卓项目打包安装
  • windows安装docker、elasticsearch、kibana、cerebro、logstash
  • QT--connect的使用
  • Java 集合(数据结构)面试题总结
  • 【MySQL】了解并操作MySQL的缓存配置与信息
  • python AssertionError: Torch not compiled with CUDA enabled
  • 浅谈Spring Cloud:认识微服务
  • vue3+ts+vite搭建脚手架(二)配置eslintprettier
  • SpringBoot接口开发总结
  • 网络传输文件的问题
  • Android开发 - 掌握ConstraintLayout(四)创建基本约束
  • angular2 简述
  • gf框架之分页模块(五) - 自定义分页
  • Invalidate和postInvalidate的区别
  • JS题目及答案整理
  • Mysql数据库的条件查询语句
  • php ci框架整合银盛支付
  • PV统计优化设计
  • python3 使用 asyncio 代替线程
  • SegmentFault 技术周刊 Vol.27 - Git 学习宝典:程序员走江湖必备
  • Spring Cloud Alibaba迁移指南(一):一行代码从 Hystrix 迁移到 Sentinel
  • Vue--数据传输
  • 讲清楚之javascript作用域
  • 使用API自动生成工具优化前端工作流
  • 使用common-codec进行md5加密
  • 从如何停掉 Promise 链说起
  • 如何在招聘中考核.NET架构师
  • ​Base64转换成图片,android studio build乱码,找不到okio.ByteString接腾讯人脸识别
  • # Java NIO(一)FileChannel
  • #Linux杂记--将Python3的源码编译为.so文件方法与Linux环境下的交叉编译方法
  • (12)Hive调优——count distinct去重优化
  • (day 2)JavaScript学习笔记(基础之变量、常量和注释)
  • (delphi11最新学习资料) Object Pascal 学习笔记---第5章第5节(delphi中的指针)
  • (pycharm)安装python库函数Matplotlib步骤
  • (补充)IDEA项目结构
  • (二刷)代码随想录第16天|104.二叉树的最大深度 559.n叉树的最大深度● 111.二叉树的最小深度● 222.完全二叉树的节点个数
  • (附源码)ssm考生评分系统 毕业设计 071114
  • (六)DockerCompose安装与配置
  • (一)80c52学习之旅-起始篇
  • (原創) 系統分析和系統設計有什麼差別? (OO)
  • (转)Linux NTP配置详解 (Network Time Protocol)
  • ***详解账号泄露:全球约1亿用户已泄露
  • .bat批处理(三):变量声明、设置、拼接、截取
  • .NET Core 和 .NET Framework 中的 MEF2
  • .NET 除了用 Task 之外,如何自己写一个可以 await 的对象?
  • .net6解除文件上传限制。Multipart body length limit 16384 exceeded
  • .net和jar包windows服务部署
  • .net后端程序发布到nignx上,通过nginx访问
  • .net连接MySQL的方法