当前位置: 首页 > news >正文

Llama-Factory的baichuan2微调

Llama-Factory:https://github.com/hiyouga/LLaMA-Factory/tree/main

请使用 

--quantization_bit 4/8

 来启用 QLoRA 训练。

默认模块应作为 --lora_target 参数的默认值,可使用 --lora_target all 参数指定全部模块。对于所有“基座”(Base)模型,--template 参数可以是 default, alpaca, vicuna 等任意值。
但“对话”(Chat)模型请务必使用对应的模板。

一、单GPU训练

1.预训练

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \--stage pt \                               # Pre——Training预训练模式--model_name_or_path path_to_llama_model \ # 模型地址--do_train \                               # 表示进行训练--dataset wiki_demo \                      # 使用的数据集--finetuning_type lora \                   # 微调的方法--lora_target W_pack \                     # LoRA作用模块:Baichuan为W_pack--output_dir path_to_pt_checkpoint \       # 断点保存:保存模型断点的位置--overwrite_cache \                        # 表示是否覆盖缓存文件--per_device_train_batch_size 4 \          # 批处理大小:每块 GPU 上处理的样本数量--gradient_accumulation_steps 4 \          # 梯度累积:梯度累积的步数(节省显存的方法)--lr_scheduler_type cosine \               # 学习率调节器:采用的学习率调节器名称--logging_steps 10 \                       # 日志间隔:每两次日志输出间的更新步数--save_steps 1000 \                        # 保存间隔:每两次断点保存间的更新步数--learning_rate 5e-5 \                     # 学习率:AdamW优化器的初始学习率--num_train_epochs 3.0 \                   # 训练轮数:需要执行的训练总轮数--plot_loss \                              # 绘制损失函数图--fp16                                     # 计算类型:是否启用fp16或bf16混合精度训练。

2.指令监督微调

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \--stage sft \--model_name_or_path path_to_llama_model \--do_train \--dataset alpaca_gpt4_zh \                  # 提示模板:构建提示词时使用的模板 --template default \              --finetuning_type lora \--lora_target W_pack \--output_dir path_to_sft_checkpoint \--overwrite_cache \--per_device_train_batch_size 4 \--gradient_accumulation_steps 4 \--lr_scheduler_type cosine \--logging_steps 10 \--save_steps 1000 \--learning_rate 5e-5 \--num_train_epochs 3.0 \--plot_loss \--fp16

3.

(1)奖励模型训练

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \--stage rm \--model_name_or_path path_to_llama_model \--do_train \--dataset comparison_gpt4_zh \              # 奖励模型训练数据集--template default \--finetuning_type lora \--lora_target W_pack \--resume_lora_training False \              # 接着上次的LoRA权重训练或创建一个新的LoRA权重--checkpoint_dir path_to_sft_checkpoint \   # 指令微调模型的断点--output_dir path_to_rm_checkpoint \        # 奖励模型的输出位置--per_device_train_batch_size 2 \--gradient_accumulation_steps 4 \--lr_scheduler_type cosine \--logging_steps 10 \--save_steps 1000 \--learning_rate 1e-6 \--num_train_epochs 1.0 \--plot_loss \--fp16

(2)PPO训练(PPO训练需要先进行上一步RM的训练,然后导入微调后模型和RM进行训练输出)

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \--stage ppo \--model_name_or_path path_to_llama_model \--do_train \--dataset alpaca_gpt4_zh \--template default \--finetuning_type lora \--lora_target W_pack \--resume_lora_training False \--checkpoint_dir path_to_sft_checkpoint \   # 加载指令微调的断点模型--reward_model path_to_rm_checkpoint \      # 奖励模型的断点路径--output_dir path_to_ppo_checkpoint \       # ppo训练的断点输出位置--per_device_train_batch_size 2 \--gradient_accumulation_steps 4 \--lr_scheduler_type cosine \--logging_steps 10 \--save_steps 1000 \--learning_rate 1e-5 \--num_train_epochs 1.0 \--plot_loss

4.DPO训练(不需要先训练RM,直接导入微调模型进行DPO训练)

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \--stage dpo \--model_name_or_path path_to_llama_model \--do_train \--dataset comparison_gpt4_zh \--template default \--finetuning_type lora \--lora_target W_pack \--resume_lora_training False \--checkpoint_dir path_to_sft_checkpoint \--output_dir path_to_dpo_checkpoint \--per_device_train_batch_size 2 \--gradient_accumulation_steps 4 \--lr_scheduler_type cosine \--logging_steps 10 \--save_steps 1000 \--learning_rate 1e-5 \--num_train_epochs 1.0 \--plot_loss \--fp16

 注:

       大规模无监督语言模型(LMs)虽然可以学习广泛的世界知识和一些推理技能,但由于其训练的完全无监督性,因此实现对其行为的精确控制是困难的。现有的获得这种可控性的方法是收集人工对模型生成相对质量的标签,并且通过人类反馈强化学习(RLHF)对无监督的 LM 进行微调,以使其与人类偏好相一致。然而,RLHF 是一个复杂且经常不太稳定的过程,它首先拟合一个反应人类偏好的奖励模型,然后通过强化学习对大型无监督 LM 进行微调以最大化评估奖励,并避免与原始模型相差太远。

        在本文中,我们使用奖励函数和最优策略间的映射,展示了约束奖励最大化问题完全可以通过单阶段策略训练进行优化 ,从本质上解决了人类偏好数据上的分类问题。我们提出的这个算法称为直接偏好优化(Direct Preference Optimization,DPO)。它具有稳定性、高性能和计算轻量级的特点,不需要拟合奖励模型,不需要在微调时从 LM 中采样,也不需要大量的超参调节。我们的实验表明了 DPO 可以微调 LMs 以对齐人类偏好,甚至比现有方法更好。值得注意的是,用 DPO 进行微调在控制生成结果的情感以及改善摘要和单轮对话的响应质量方面表现出更好的能力,同时在实现和训练时的难度大大降低。

二、多卡训练

三、模型评估

CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \--model_name_or_path path_to_llama_model \--finetuning_type lora \--checkpoint_dir path_to_checkpoint \--template vanilla \--task ceval \--split validation \--lang zh \--n_shot 5 \--batch_size 4

四、模型预测

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \--stage sft \--model_name_or_path path_to_llama_model \--do_predict \--dataset alpaca_gpt4_zh \--template default \--finetuning_type lora \--checkpoint_dir path_to_checkpoint \--output_dir path_to_predict_result \--per_device_eval_batch_size 8 \--max_samples 100 \                     # 最大样本数:每个数据集最多使用的样本数--predict_with_generate

我们建议在量化模型的预测中使用 

--per_device_eval_batch_size=1 
--max_target_length 128

微调训练后生成的文件夹path_to_sft_checkpoint中包括:

1.checkpoint-xxx 间隔固定step生成的模型断点

2.runs 文件夹用于tensorboard可视化训练过程

3.lora adapter模型、配置

4.分词器脚本、配置、模型

5.训练日志

6.loss曲线

 

tensorboard: 

 

相关文章:

  • 2000-2021年县域指标统计数据库
  • Python编程-二万字浅谈装饰器原理与装饰器设计模式和函数式编程案例讲解
  • linux的make和makefile
  • 第三节 zookeeper基础应用与实战2
  • 单片机的省电模式及策略
  • FileZilla Server 1.8.1内网搭建
  • MySQL 基础知识(一)之数据库和 SQL 概述
  • BUUCTF misc 专题(47)[SWPU2019]神奇的二维码
  • 【初始C++】引用的概念及使用场景、引用与指针的区别、内联函数、类型推导关键字auto、范围for循环、指针空值nullptr
  • Excel+VBA处理高斯光束
  • 毕业设计vue+php幼儿园网站系统yl567
  • 【Java EE初阶十二】网络编程TCP/IP协议(二)
  • Duilib List 控件学习
  • 第三百一十回
  • ELAdmin 部署
  • [ 一起学React系列 -- 8 ] React中的文件上传
  • 《用数据讲故事》作者Cole N. Knaflic:消除一切无效的图表
  • 【跃迁之路】【463天】刻意练习系列222(2018.05.14)
  • Fabric架构演变之路
  • jQuery(一)
  • JS题目及答案整理
  • MySQL主从复制读写分离及奇怪的问题
  • niucms就是以城市为分割单位,在上面 小区/乡村/同城论坛+58+团购
  • Python socket服务器端、客户端传送信息
  • tab.js分享及浏览器兼容性问题汇总
  • vue-router的history模式发布配置
  • webgl (原生)基础入门指南【一】
  • -- 查询加强-- 使用如何where子句进行筛选,% _ like的使用
  • 基于OpenResty的Lua Web框架lor0.0.2预览版发布
  • 快速构建spring-cloud+sleuth+rabbit+ zipkin+es+kibana+grafana日志跟踪平台
  • 因为阿里,他们成了“杭漂”
  • 源码安装memcached和php memcache扩展
  • “十年磨一剑”--有赞的HBase平台实践和应用之路 ...
  • 7行Python代码的人脸识别
  • 如何用纯 CSS 创作一个菱形 loader 动画
  • ​​​​​​​GitLab 之 GitLab-Runner 安装,配置与问题汇总
  • # include “ “ 和 # include < >两者的区别
  • #1015 : KMP算法
  • #stm32整理(一)flash读写
  • (翻译)Entity Framework技巧系列之七 - Tip 26 – 28
  • (分享)一个图片添加水印的小demo的页面,可自定义样式
  • (四)七种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB
  • (转)Oracle 9i 数据库设计指引全集(1)
  • (转)创业的注意事项
  • ***linux下安装xampp,XAMPP目录结构(阿里云安装xampp)
  • ..thread“main“ com.fasterxml.jackson.databind.JsonMappingException: Jackson version is too old 2.3.1
  • .[backups@airmail.cc].faust勒索病毒的最新威胁:如何恢复您的数据?
  • .NET 8 中引入新的 IHostedLifecycleService 接口 实现定时任务
  • .NET Windows:删除文件夹后立即判断,有可能依然存在
  • .net实现客户区延伸至至非客户区
  • .Net中的设计模式——Factory Method模式
  • [1127]图形打印 sdutOJ
  • [20180224]expdp query 写法问题.txt
  • [BeginCTF]真龙之力
  • [BZOJ 3531][Sdoi2014]旅行(树链剖分+线段树)