当前位置: 首页 > news >正文

Paddlets时间序列集成模型回测实战:MLPRegressor、NHiTSModel与RNNBlockRegressor

好的,我们继续深入理解代码的每个部分。以下是每个主要模块的详细解释:

1. 导入模块和库

import json
import os
import glob
import pandas as pd
from tqdm import tqdm
from paddlets.datasets import TSDataset
from paddlets.transform import StandardScaler
from paddlets.models.forecasting import MLPRegressor, NHiTSModel, RNNBlockRegressor
from paddlets.ensemble import WeightingEnsembleForecaster
import ray
  • json: 用于处理JSON文件。
  • os: 用于处理文件和目录操作。
  • glob: 用于获取匹配特定模式的文件路径。
  • pandas: 用于数据处理和分析,尤其是表格数据。
  • tqdm: 用于显示进度条,帮助跟踪循环的进度。
  • paddlets: 时间序列预测相关的库。
  • ray: 用于并行计算的库。

2. 定义和创建目录

dirs = ["forecasting_all_result_center","pic_forecasting_center","model_forecasting_center_2048_a_b_5_100","best_forecasting_param_center"
]for dir_name in dirs:os.makedirs(dir_name, exist_ok=True)
  • dirs: 定义了多个用于存储不同类型结果的目录。
  • os.makedirs: 创建目录,如果目录已存在,则不报错。

3. 加载股票映射

with open("./stock_mapping.json", "r") as f:stock_mapping = json.load(f)
  • stock_mapping.json文件中加载股票的映射关系,以便后续使用。

4. 加载CSV数据

csv_paths = glob.glob(os.path.join("./tu_share_data_day", "*.csv"))
sum_dam_data = []for csv_path in tqdm(csv_paths):new_data = pd.read_csv(csv_path)if len(new_data) < 2048 or new_data.iloc[0, 2] < 5 or new_data.iloc[0, 2] > 100:continuenew_data = new_data[::-1].iloc[:2048]new_data['index_new'] = range(1, len(new_data) + 1)sum_dam_data.append(new_data)
  • 使用glob获取所有CSV文件路径,并遍历每个文件。
  • 读取数据并进行过滤,确保符合条件(如数据长度、价格区间)。
  • 将数据反转并取最后2048条,添加索引列。

5. 构建时间序列数据集

dam_data = pd.concat(sum_dam_data)dataset = TSDataset.load_from_dataframe(dam_data,group_id='ts_code',time_col="index_new",target_cols=['high', 'low']
)
  • 将所有符合条件的数据合并成一个DataFrame。
  • 使用TSDataset将数据转换为时间序列格式,指定分组、时间列和目标列。

6. 初始化标准化器

scaler = StandardScaler().fit(dataset)
dataset = scaler.transform(dataset)
  • 使用StandardScaler对数据进行标准化处理,使模型训练更加稳定。

7. 初始化Ray进行并行计算

ray.init()
  • 初始化Ray,使得后续的计算能够并行执行。

8. 定义并行处理函数

@ray.remote
def process_csv_file(csv_path, scaler):...
  • 使用@ray.remote装饰器定义一个可以被Ray并行化的函数,处理每个CSV文件的逻辑。

9. 设置模型参数和加载模型

nhits_params = {'sampling_stride': 24, 'eval_metrics': ["mse", "mae"], 'batch_size': 32, 'max_epochs': 100, 'patience': 10}
rnn_params = nhits_params.copy()
mlp_params = nhits_params.copy()
mlp_params['use_bn'] =

相关文章:

  • 15 Midjourney从零到商用·实战篇:建筑设计与室内设计
  • 8.使用 VSCode 过程中的英语积累 - Help 菜单(每一次重点积累 5 个单词)
  • (28)oracle数据迁移(容器)-部署包资源
  • OpenCV视频I/O(7)视频采集类VideoCapture之初始化视频捕获设备或打开一个视频文件函数open()的使用
  • 【HTML|第1期】HTML5视频(Video)元素详解:从起源到应用
  • 智影S100户外直接采集输出的是绝对坐标吗?内业是否需要控制点进行配准?
  • access mysql
  • 星辰计划04-深入理解kafka的消息存储和索引设计
  • SpringBoot的概述与搭建
  • SIMETRIX 探头和测量
  • [java][gps]GPS坐标系转换
  • JVM总结
  • Python in Excel作图分析实战!
  • JAVA入门1——理论+helloworld
  • Word导出样式模板,应用到其他所有word
  • [Vue CLI 3] 配置解析之 css.extract
  • 【笔记】你不知道的JS读书笔记——Promise
  • Apache Zeppelin在Apache Trafodion上的可视化
  • download使用浅析
  • ES10 特性的完整指南
  • GitUp, 你不可错过的秀外慧中的git工具
  • HashMap ConcurrentHashMap
  • HTTP中GET与POST的区别 99%的错误认识
  • IndexedDB
  • Java知识点总结(JavaIO-打印流)
  • laravel with 查询列表限制条数
  • Linux链接文件
  • MySQL用户中的%到底包不包括localhost?
  • NLPIR语义挖掘平台推动行业大数据应用服务
  • Python十分钟制作属于你自己的个性logo
  • 创建一个Struts2项目maven 方式
  • 从0实现一个tiny react(三)生命周期
  • 多线程 start 和 run 方法到底有什么区别?
  • 七牛云 DV OV EV SSL 证书上线,限时折扣低至 6.75 折!
  • 前端技术周刊 2019-01-14:客户端存储
  • 如何使用 JavaScript 解析 URL
  • 如何使用Mybatis第三方插件--PageHelper实现分页操作
  • 微信开源mars源码分析1—上层samples分析
  • 小程序button引导用户授权
  • 正则学习笔记
  • 阿里云ACE认证之理解CDN技术
  • ​Z时代时尚SUV新宠:起亚赛图斯值不值得年轻人买?
  • # 再次尝试 连接失败_无线WiFi无法连接到网络怎么办【解决方法】
  • #控制台大学课堂点名问题_课堂随机点名
  • #职场发展#其他
  • (003)SlickEdit Unity的补全
  • (4)logging(日志模块)
  • (delphi11最新学习资料) Object Pascal 学习笔记---第5章第5节(delphi中的指针)
  • (LeetCode 49)Anagrams
  • (LeetCode C++)盛最多水的容器
  • (pytorch进阶之路)CLIP模型 实现图像多模态检索任务
  • (Redis使用系列) Springboot 使用redis实现接口幂等性拦截 十一
  • (安全基本功)磁盘MBR,分区表,活动分区,引导扇区。。。详解与区别
  • (二)延时任务篇——通过redis的key监听,实现延迟任务实战
  • (附源码)springboot 智能停车场系统 毕业设计065415