当前位置: 首页 > news >正文

firefly推理和微调qwen

1.conda环境准备


git clone https://github.com/yangjianxin1/Firefly.gitconda create -n firefly  python=3.10cd ./Fireflypip install -r requirements.txt  -i https://pypi.tuna.tsinghua.edu.cn/simple 然后还有依赖要安装pip install git+https://github.com/huggingface/peft.git    -i https://pypi.tuna.tsinghua.edu.cn/simple 
pip install git+https://github.com/huggingface/accelerate.git    -i https://pypi.tuna.tsinghua.edu.cn/simple 
pip install git+https://github.com/huggingface/transformers.git    -i https://pypi.tuna.tsinghua.edu.cn/simple 
pip install git+https://github.com/TimDettmers/bitsandbytes.git     -i https://pypi.tuna.tsinghua.edu.cn/simple 
pip install einops transformers_stream_generator
pip install tiktoken

2.推理

这里我是将chat.py代码放到component文件夹下了,所以untils,而不是component.utils

from transformers import AutoTokenizer
import torchimport sys
from utils import ModelUtils
"""
单轮对话,不具有对话历史的记忆功能
"""def main():# 使用合并后的模型进行推理model_name_or_path = '/home/cxh/Qwen-7B'adapter_name_or_path = None# 使用base model和adapter进行推理,无需手动合并权重# model_name_or_path = 'baichuan-inc/Baichuan-7B'# adapter_name_or_path = 'YeungNLP/firefly-baichuan-7b-qlora-sft'# 是否使用4bit进行推理,能够节省很多显存,但效果可能会有一定的下降load_in_4bit = False# 生成超参配置max_new_tokens = 500top_p = 0.9temperature = 0.35repetition_penalty = 1.0device = 'cuda'# 加载模型model = ModelUtils.load_model(model_name_or_path,load_in_4bit=load_in_4bit,adapter_name_or_path=adapter_name_or_path).eval()tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,trust_remote_code=True,# llama不支持fastuse_fast=False if model.config.model_type == 'llama' else True,padding_side='left'  # 设置左侧填充)print("tokenizer.__class__.__name__",tokenizer.__class__.__name__)# QWenTokenizer比较特殊,pad_token_id、bos_token_id、eos_token_id均为None。eod_id对应的token为<|endoftext|>if tokenizer.__class__.__name__ == 'QWenTokenizer':tokenizer.pad_token_id = tokenizer.eod_idtokenizer.bos_token_id = tokenizer.eod_idtokenizer.eos_token_id = tokenizer.eod_idtext = input('User:')while True:text = text.strip()# chatglm使用官方的数据组织格式print(123)if model.config.model_type == 'chatglm':print("chatglm")text = '[Round 1]\n\n问:{}\n\n答:'.format(text)input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)# 为了兼容qwen-7b,因为其对eos_token进行tokenize,无法得到对应的eos_token_idelse:print("QWEN")input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)with torch.no_grad():outputs = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,eos_token_id=tokenizer.eos_token_id)outputs = outputs.tolist()[0][len(input_ids[0]):]response = tokenizer.decode(outputs)response = response.strip().replace(tokenizer.eos_token, "").strip()print("123321")print("Firefly:{}".format(response))text = input('User:')if __name__ == '__main__':main()

3.微调

pip install loguru  astunparse

torchrun --nproc_per_node=1 train_qlora.py --train_args_file /home/cxh/Firefly/train_args/qlora/qwen-7b-sft-qlora.json

train_qlora.py内容如下
 

from transformers import AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (set_seed,HfArgumentParser,TrainingArguments,AutoModelForCausalLM
)
import argparse
from loguru import logger
import os
from os.path import join
import torch
import bitsandbytes as bnb
from collections import defaultdictfrom collator import SFTDataCollator
from dataset import (SFTDataset,ChatGLM2SFTDataset,ChatGLM3SFTDataset,MistralSFTDataset,ZephyrSFTDataset,QwenSFTDataset
)
from argument import QLoRAArguments
from trainer import LoRATrainer
# from component.loss import TargetLMLosstorch.cuda.empty_cache()
def verify_model_dtype(model):"""查看模型种各种类型的参数的情况"""dtype2param_num = defaultdict(int)  # 每种数据类型的参数量dtype2param_name = defaultdict(list)  # 每种数据类型的参数名称dtype2trainable_param_num = defaultdict(int)  # 每种数据类型参与训练的参数量dtype2trainable_param_name = defaultdict(list)  # 每种数据类型参与训练的参数名称for name, p in model.named_parameters():dtype = p.dtypedtype2param_num[dtype] += p.numel()dtype2param_name[dtype].append(name)if p.requires_grad:dtype2trainable_param_num[dtype] += p.numel()dtype2trainable_param_name[dtype].append(name)# 统计全部参数中,各种类型参数分布total = 0print('verify all params of the model')for k, v in dtype2param_num.items():total += vfor k, v in dtype2param_num.items():print(k, v, v / total)for k, v in dtype2trainable_param_name.items():print(k, v)print()# 统计可训练参数中,各种类型参数分布print('verify trainable params the model')total_trainable = 0for k, v in dtype2trainable_param_num.items():total_trainable += vfor k, v in dtype2trainable_param_num.items():print(k, v, v / total_trainable)for k, v in dtype2trainable_param_num.items():print(k, v)def find_all_linear_names(model):"""找出所有全连接层,为所有全连接添加adapter"""cls = bnb.nn.Linear4bitlora_module_names = set()for name, module in model.named_modules():if isinstance(module, cls):names = name.split('.')lora_module_names.add(names[0] if len(names) == 1 else names[-1])if 'lm_head' in lora_module_names:  # needed for 16-bitlora_module_names.remove('lm_head')return list(lora_module_names)def setup_everything():parser = argparse.ArgumentParser()parser.add_argument("--train_args_file", type=str, default='train_args/baichuan-sft-qlora.json', help="")args = parser.parse_args()train_args_file = args.train_args_file# 读取训练的参数配置parser = HfArgumentParser((QLoRAArguments, TrainingArguments))# 解析得到自定义参数,以及自带参数args, training_args = parser.parse_json_file(json_file=train_args_file)# 创建输出目录if not os.path.exists(training_args.output_dir):os.makedirs(training_args.output_dir)# logger.add(join(training_args.output_dir, 'train.log'))# logger.info("train_args:{}".format(training_args))# 设置随机种子set_seed(training_args.seed)return args, training_args
def init_components(args, training_args):logger.info('Initializing components...')training_args.ddp_find_unused_parameters = Falselocal_rank = int(os.environ.get('LOCAL_RANK', '0'))device_map = {'': local_rank}# 加载模型并进行4bit量化配置quantization_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_compute_dtype=torch.float16,bnb_4bit_use_double_quant=True,bnb_4bit_quant_type="nf4",llm_int8_threshold=6.0,llm_int8_has_fp16_weight=False,)# 确保仅使用 `quantization_config` 参数,移除 `load_in_4bit`model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,device_map=device_map,torch_dtype=torch.float16,trust_remote_code=True,quantization_config=quantization_config  # 使用4bit量化配置)model.config.use_cache = Falseif 'output_router_logits' in model.config.to_dict():logger.info('set output_router_logits as True')model.config.output_router_logits = Truetokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path,trust_remote_code=True,use_fast=False if model.config.model_type == 'llama' else True)if tokenizer.__class__.__name__ == 'QWenTokenizer':tokenizer.pad_token_id = tokenizer.eod_idtokenizer.bos_token_id = tokenizer.eod_idtokenizer.eos_token_id = tokenizer.eod_idelif tokenizer.__class__.__name__ != 'ChatGLMTokenizer':assert tokenizer.eos_token_id is not Noneassert tokenizer.bos_token_id is not Nonetokenizer.pad_token_id = tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_idmodel = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)print(f'memory footprint of model: {model.get_memory_footprint()/(1024*1024*1024)} GB')target_modules = find_all_linear_names(model)config = LoraConfig(r=args.lora_rank,lora_alpha=args.lora_alpha,target_modules=target_modules,lora_dropout=args.lora_dropout,bias="none",task_type="CAUSAL_LM",)model = get_peft_model(model, config)model.print_trainable_parameters()model.config.torch_dtype = torch.float32verify_model_dtype(model)if 'chatglm2' in args.model_name_or_path.lower():train_dataset = ChatGLM2SFTDataset(args.train_file, tokenizer, args.max_seq_length)elif 'chatglm3' in args.model_name_or_path.lower():train_dataset = ChatGLM3SFTDataset(args.train_file, tokenizer, args.max_seq_length)elif 'mistral' in args.model_name_or_path.lower() or 'mixtral' in args.model_name_or_path.lower():train_dataset = MistralSFTDataset(args.train_file, tokenizer, args.max_seq_length)elif 'zephyr' in args.model_name_or_path.lower():train_dataset = ZephyrSFTDataset(args.train_file, tokenizer, args.max_seq_length)elif 'qwen' in args.model_name_or_path.lower():train_dataset = QwenSFTDataset(args.train_file, tokenizer, args.max_seq_length)else:train_dataset = SFTDataset(args.train_file, tokenizer, args.max_seq_length)data_collator = SFTDataCollator(tokenizer, args.max_seq_length)trainer = LoRATrainer(model=model,args=training_args,train_dataset=train_dataset,data_collator=data_collator,)return trainerdef main():# 进行一些配置和检查args, training_args = setup_everything()# 加载各种组件trainer = init_components(args, training_args)# 开始训练logger.info("*** starting training ***")# todo resume from checkpoint# https://github.com/huggingface/transformers/issues/24252train_result = trainer.train()# 保存最后的checkpointtrainer.save_model(training_args.output_dir)  # Saves the tokenizer too# 保存训练指标metrics = train_result.metricstrainer.log_metrics("train", metrics)trainer.save_metrics("train", metrics)trainer.save_state()if __name__ == "__main__":main()

 qwen-7b-sft-qlora.json内如如下

{"output_dir": "output/firefly-qwen-7b","model_name_or_path": "/home/cxh/Qwen-7B","train_file": "/home/cxh/Firefly/data/dummy_data.jsonl","num_train_epochs": 20,"per_device_train_batch_size": 4,"gradient_accumulation_steps": 2,"learning_rate": 2e-4,"max_seq_length": 1024,"logging_steps": 300,"save_steps": 500,"save_total_limit": 1,"lr_scheduler_type": "constant_with_warmup","warmup_steps": 3000,"lora_rank": 64,"lora_alpha": 16,"lora_dropout": 0.05,"gradient_checkpointing": true,"disable_tqdm": false,"optim": "paged_adamw_32bit","seed": 42,"fp16": true,"report_to": "tensorboard","dataloader_num_workers": 0,"save_strategy": "steps","weight_decay": 0,"max_grad_norm": 0.3,"remove_unused_columns": false
}

5.权重合并

在训练中,我们只保存adapter的权重,不保存合并后的模型权重。

adapter与base model进行权重合并

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
"""
使用该脚本,将lora的权重合并大base model中
"""def merge_lora_to_base_model():model_name_or_path = '/home/cxh/Qwen-7B'adapter_name_or_path = '/home/cxh/Firefly/component/output/firefly-qwen-7b'save_path = '/home/cxh/Firefly/script/firefly-Qwen-7B-qlora-sft-merge'config = AutoConfig.from_pretrained(model_name_or_path)tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,trust_remote_code=True,# llama不支持fastuse_fast=False if config.model_type == 'llama' else True)model = AutoModelForCausalLM.from_pretrained(model_name_or_path,trust_remote_code=True,low_cpu_mem_usage=True,torch_dtype=torch.float16,# device_map='auto',device_map={'': 'cpu'})model = PeftModel.from_pretrained(model, adapter_name_or_path, device_map={'': 'cpu'})model = model.merge_and_unload()tokenizer.save_pretrained(save_path)model.save_pretrained(save_path)if __name__ == '__main__':merge_lora_to_base_model()

6.合并后的模型推理


from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
"""
单轮对话,不具有对话历史的记忆功能
"""def main():model_name = '/home/cxh/Firefly/script/firefly-Qwen-7B-qlora-sft-merge'max_new_tokens = 500top_p = 0.9temperature = 0.35repetition_penalty = 1.0device = 'cuda'input_pattern = '<s>{}</s>'model = AutoModelForCausalLM.from_pretrained(model_name,trust_remote_code=True,low_cpu_mem_usage=True,torch_dtype=torch.float16,device_map='auto').to(device).eval()tokenizer = AutoTokenizer.from_pretrained(model_name,trust_remote_code=True,# llama不支持fastuse_fast=False if model.config.model_type == 'llama' else True)text = input('User:')while True:text = text.strip()text = input_pattern.format(text)input_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids.to(device)with torch.no_grad():outputs = model.generate(input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True,top_p=top_p, temperature=temperature, repetition_penalty=repetition_penalty,eos_token_id=tokenizer.eos_token_id)outputs = outputs.tolist()[0][len(input_ids[0]):]response = tokenizer.decode(outputs)response = response.strip().replace(text, "").replace('</s>', "").replace('<s>', "").strip()print("Firefly:{}".format(response))text = input('User:')if __name__ == '__main__':main()

6.docker打包conda环境

 conda env export --name firefly > environment.yml

environment.yml内容如下
 

name: firefly
channels:- defaults
dependencies:- _libgcc_mutex=0.1=main- _openmp_mutex=5.1=1_gnu- bzip2=1.0.8=h5eee18b_6- ca-certificates=2024.7.2=h06a4308_0- ld_impl_linux-64=2.38=h1181459_1- libffi=3.4.4=h6a678d5_1- libgcc-ng=11.2.0=h1234567_1- libgomp=11.2.0=h1234567_1- libstdcxx-ng=11.2.0=h1234567_1- libuuid=1.41.5=h5eee18b_0- ncurses=6.4=h6a678d5_0- openssl=3.0.14=h5eee18b_0- pip=24.0=py310h06a4308_0- python=3.10.14=h955ad1f_1- readline=8.2=h5eee18b_0- setuptools=72.1.0=py310h06a4308_0- sqlite=3.45.3=h5eee18b_0- tk=8.6.14=h39e8969_0- tzdata=2024a=h04d1e81_0- wheel=0.43.0=py310h06a4308_0- xz=5.4.6=h5eee18b_1- zlib=1.2.13=h5eee18b_1- pip:- absl-py==2.1.0- accelerate==0.33.0.dev0- astunparse==1.6.3- bitsandbytes==0.43.1- certifi==2024.7.4- charset-normalizer==3.3.2- einops==0.8.0- filelock==3.15.4- fsspec==2024.6.1- grpcio==1.65.4- huggingface-hub==0.24.5- idna==3.7- jinja2==3.1.4- loguru==0.7.2- markdown==3.6- markupsafe==2.1.5- mpmath==1.3.0- networkx==3.3- numpy==1.26.4- nvidia-cublas-cu12==12.1.3.1- nvidia-cuda-cupti-cu12==12.1.105- nvidia-cuda-nvrtc-cu12==12.1.105- nvidia-cuda-runtime-cu12==12.1.105- nvidia-cudnn-cu12==9.1.0.70- nvidia-cufft-cu12==11.0.2.54- nvidia-curand-cu12==10.3.2.106- nvidia-cusolver-cu12==11.4.5.107- nvidia-cusparse-cu12==12.1.0.106- nvidia-nccl-cu12==2.20.5- nvidia-nvjitlink-cu12==12.6.20- nvidia-nvtx-cu12==12.1.105- packaging==24.1- peft==0.12.1.dev0- protobuf==4.25.4- psutil==6.0.0- pyyaml==6.0.2- regex==2024.7.24- requests==2.32.3- safetensors==0.4.4- six==1.16.0- sympy==1.13.2- tensorboard==2.17.0- tensorboard-data-server==0.7.2- tiktoken==0.7.0- tokenizers==0.19.1- torch==2.4.0- tqdm==4.66.5- transformers==4.45.0.dev0- transformers-stream-generator==0.0.5- triton==3.0.0- typing-extensions==4.12.2- urllib3==2.2.2- werkzeug==3.0.3
prefix: /home/cxh/.conda/envs/firefly

创建dockerfile

# 使用官方的 Anaconda 基础镜像
FROM continuumio/miniconda3# 设置工作目录
WORKDIR /app# 复制环境配置文件到镜像中
COPY environment.yml .# 创建 Conda 环境并安装所有依赖
RUN conda env create -f environment.yml# 激活环境
SHELL ["conda", "run", "-n", "<firefly>", "/bin/bash", "-c"]# 安装你的代码或其他依赖项
COPY . .# 默认的入口点
CMD ["python", "docker-file.py"]

docker build -t firefly_env .

docker run --rm firefly_env

7.将conda环境挪到google云端硬盘配合colab使用

或者直接一开始就在云盘环境里安装conda环境

conda环境安装指定位置的虚拟环境_conda指定路径创建虚拟环境-CSDN博客

python=3.10.14

https://colab.research.google.com/drive/1vx8EqjtxAxGXF-WCse7a115fj-uCiUpH?authuser=1#scrollTo=FyBlnNyDdADz

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Appium基础
  • 背包九讲(动态规划)
  • IO流(完善)
  • 2.4 playwright 实战-爬取某宝商品信息
  • 四款录屏大师,一键搞定!新手也能快速上手?
  • Python数值计算(24)——PCHIP
  • Chapter 9 Operational Amplifiers
  • 快速上手Spring Boot
  • 6G:融合5G与AI,重塑网络交互与智能决策的未来
  • NB模组AT 命令用法记录
  • 如何使用yolov5-master进行训练
  • 书生.浦江大模型实战训练营——(四)书生·浦语大模型全链路开源开放体系
  • JavaScript高阶笔记总结第三天:(JavaScript高阶完结)
  • JavaScript中的字符串与数字转换
  • 人工智能GPU算力评估分析
  • 《剑指offer》分解让复杂问题更简单
  • 【MySQL经典案例分析】 Waiting for table metadata lock
  • 【个人向】《HTTP图解》阅后小结
  • Android系统模拟器绘制实现概述
  • Angular数据绑定机制
  • export和import的用法总结
  • JAVA并发编程--1.基础概念
  • js ES6 求数组的交集,并集,还有差集
  • js面向对象
  • Map集合、散列表、红黑树介绍
  • MySQL-事务管理(基础)
  • SAP云平台运行环境Cloud Foundry和Neo的区别
  • SpringBoot 实战 (三) | 配置文件详解
  • Webpack4 学习笔记 - 01:webpack的安装和简单配置
  • 从0到1:PostCSS 插件开发最佳实践
  • 高程读书笔记 第六章 面向对象程序设计
  • 湖南卫视:中国白领因网络偷菜成当代最寂寞的人?
  • 基于Android乐音识别(2)
  • 前端性能优化——回流与重绘
  • 腾讯优测优分享 | Android碎片化问题小结——关于闪光灯的那些事儿
  • 走向全栈之MongoDB的使用
  • ​埃文科技受邀出席2024 “数据要素×”生态大会​
  • #多叉树深度遍历_结合深度学习的视频编码方法--帧内预测
  • #我与Java虚拟机的故事#连载14:挑战高薪面试必看
  • $jQuery 重写Alert样式方法
  • (01)ORB-SLAM2源码无死角解析-(66) BA优化(g2o)→闭环线程:Optimizer::GlobalBundleAdjustemnt→全局优化
  • (六)Hibernate的二级缓存
  • (三)模仿学习-Action数据的模仿
  • (十一)JAVA springboot ssm b2b2c多用户商城系统源码:服务网关Zuul高级篇
  • (五)MySQL的备份及恢复
  • (一)VirtualBox安装增强功能
  • (源码版)2024美国大学生数学建模E题财产保险的可持续模型详解思路+具体代码季节性时序预测SARIMA天气预测建模
  • (转)shell调试方法
  • (转)shell中括号的特殊用法 linux if多条件判断
  • (转载)从 Java 代码到 Java 堆
  • .net 4.0 A potentially dangerous Request.Form value was detected from the client 的解决方案
  • .Net Web项目创建比较不错的参考文章
  • .NET 设计模式初探
  • .NET实现之(自动更新)
  • .Net小白的大学四年,内含面经