当前位置: 首页 > news >正文

DataWhale AI夏令营-英特尔-阿里天池LLM Hackathon

英特尔-阿里天池LLM Hackathon

  • 项目思路
    • 项目背景
    • 项目思路
  • Lora微调Qwen模型
  • 使用ipex_llm推理加速
  • Gradio交互

项目名称:医疗问答助手

项目思路

项目背景

在当今医疗领域,智能问答系统正在逐步成为辅助医疗诊断的重要工具。随着自然语言处理技术的发展,基于大模型的问答系统在处理复杂医疗问题时展现出了巨大的潜力。Qwen2-1.5B模型作为一个大型预训练语言模型,拥有强大的语言理解和生成能力,但在特定领域应用时,往往需要进一步的微调和优化。为了提升医疗问答系统的准确性,本项目采用了LoRA(Low-Rank Adaptation)微调方法,并通过ipex_llm框架在指定的CPU平台上进行推理加速。

项目思路

明确了项目需求之后可以将本次项目分为三个部分:Lora微调Qwen模型、使用ipex_llm在CPU上进行推理加速、使用Gradio交互。

Lora微调Qwen模型

我们本次项目的目的是完成一个医疗问答机器人,训练的首先需要收集数据,我们使用github上开源的医疗问答数据集,数据集包含了2.7w条真实的问答数据(github链接有点久远了时间我忘记了,如果有需要可以私信我我发给您)。
在这里插入图片描述
Qwen的Lora我们在之前的博客中有提到过,在这里就不细说了,详见Qwen2-1.5B微调+推理

import torch
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer
from peft import LoraConfig, TaskType, get_peft_model, PeftModeldataset = load_dataset("csv", data_files="./问答.csv", split="train")
dataset = dataset.filter(lambda x: x["answer"] is not None)
datasets = dataset.train_test_split(test_size=0.1)tokenizer = AutoTokenizer.from_pretrained("./Qwen2-1.5B-Instruct", trust_remote_code=True)def process_func(example):MAX_LENGTH = 768input_ids, attention_mask, labels = [], [], []instruction = example["question"].strip()     # queryinstruction = tokenizer(f"<|im_start|>system\n你是医学领域的人工助手章鱼哥<|im_end|>\n<|im_start|>user\n{example['question']}<|im_end|>\n<|im_start|>assistant\n",add_special_tokens=False,)response = tokenizer(f"{example['answer']}", add_special_tokens=False)        # \n response, 缺少eos tokeninput_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]attention_mask = (instruction["attention_mask"] + response["attention_mask"] + [1])labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]if len(input_ids) > MAX_LENGTH:input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}tokenized_ds = datasets['train'].map(process_func, remove_columns=['id', 'question', 'answer'])
tokenized_ts = datasets['test'].map(process_func, remove_columns=['id', 'question', 'answer'])model = AutoModelForCausalLM.from_pretrained("./Qwen2-1.5B-Instruct", trust_remote_code=True)config = LoraConfig(target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"], modules_to_save=["post_attention_layernorm"])model = get_peft_model(model, config)args = TrainingArguments(output_dir="./law",per_device_train_batch_size=4,gradient_accumulation_steps=16,gradient_checkpointing=True,logging_steps=6,num_train_epochs=10,learning_rate=1e-4,remove_unused_columns=False,save_strategy="epoch"
)
model.enable_input_require_grads()trainer = Trainer(model=model,args=args,train_dataset=tokenized_ds.select(range(400)),data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

训练结束得到微调后的权重,打包下载即可。
在这里插入图片描述

使用ipex_llm推理加速

导入需要的包
ipex是Intel公司研发优化大语言模型 (LLM) 在其硬件(Intel CPU)上运行而开发的一组扩展库和工具。

import os
import torch
import time
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM
from peft import PeftModel

由于实在Cpu推理可以根据核心数设置线程

# 设置OpenMP线程数为8, 优化CPU并行计算性能
os.environ["OMP_NUM_THREADS"] = "8"# base_model_name = "qwen2chat_int4"
# model = AutoModelForCausalLM.load_low_bit(base_model_name, trust_remote_code=True)# 加载基础模型和分词器
base_model_name = "Qwen2-1-5B-Instruct"  # 替换为你的基础模型名称
model = AutoModelForCausalLM.from_pretrained(base_model_name,torch_dtype="auto",device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

合并Lora

# 加载LoRA微调后的权重
lora_checkpoint = "./checkpoint-781"
lora_model = PeftModel.from_pretrained(model, lora_checkpoint)

输入Prompt测试

# 定义输入prompt
prompt = "头疼怎么治疗呢"# 构建符合模型输入格式的消息列表
messages = [{"role": "user", "content": prompt}]

开启推理模式,在这部分其实有一个缺陷,就是合并Lora后的模型推理速度非常慢,大概是普通模型的五倍,欢迎有大佬能指点。

# 使用推理模式,减少内存使用并提高推理速度
with torch.inference_mode():# 应用聊天模板,将消息转换为模型输入格式的文本text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)# 将文本转换为模型输入张量,并移至CPU (如果使用GPU,这里应改为.to('cuda'))model_inputs = tokenizer([text], return_tensors="pt").to('cpu')st = time.time()# 生成回答, max_new_tokens限制生成的最大token数generated_ids = lora_model.generate(model_inputs.input_ids, max_new_tokens=512)end = time.time()# 初始化一个空列表,用于存储处理后的generated_idsprocessed_generated_ids = []# 使用zip函数同时遍历model_inputs.input_ids和generated_idsfor input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):# 计算输入序列的长度input_length = len(input_ids)# 从output_ids中截取新生成的部分# 这是通过切片操作完成的,只保留input_length之后的部分new_tokens = output_ids[input_length:]# 将新生成的token添加到处理后的列表中processed_generated_ids.append(new_tokens)# 将处理后的列表赋值回generated_idsgenerated_ids = processed_generated_ids# 解码模型输出,转换为可读文本response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

打印推理时间和结果

    # 打印推理时间print(f'Inference time: {end-st:.2f} s')# 打印原始promptprint('-'*20, 'Prompt', '-'*20)print(text)# 打印模型生成的输出print('-'*20, 'Output', '-'*20)print(response)

一站式py脚本

import os
import torch
import time
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM
from peft import PeftModel# 设置OpenMP线程数为8, 优化CPU并行计算性能
os.environ["OMP_NUM_THREADS"] = "8"# 加载基础模型和分词器
base_model_name = "Qwen2-1-5B-Instruct"  # 替换为你的基础模型名称
model = AutoModelForCausalLM.from_pretrained(base_model_name,torch_dtype="auto",device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)# 加载LoRA微调后的权重
lora_checkpoint = "./checkpoint-5000"
lora_model = PeftModel.from_pretrained(model, lora_checkpoint)# 定义输入prompt
prompt = "头疼怎么治疗呢"# 构建符合模型输入格式的消息列表
messages = [{"role": "user", "content": prompt}]# 使用推理模式,减少内存使用并提高推理速度
with torch.inference_mode():# 应用聊天模板,将消息转换为模型输入格式的文本text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)# 将文本转换为模型输入张量,并移至CPU (如果使用GPU,这里应改为.to('cuda'))model_inputs = tokenizer([text], return_tensors="pt").to('cpu')st = time.time()# 生成回答, max_new_tokens限制生成的最大token数generated_ids = lora_model.generate(model_inputs.input_ids, max_new_tokens=512)end = time.time()# 初始化一个空列表,用于存储处理后的generated_idsprocessed_generated_ids = []# 使用zip函数同时遍历model_inputs.input_ids和generated_idsfor input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):# 计算输入序列的长度input_length = len(input_ids)# 从output_ids中截取新生成的部分# 这是通过切片操作完成的,只保留input_length之后的部分new_tokens = output_ids[input_length:]# 将新生成的token添加到处理后的列表中processed_generated_ids.append(new_tokens)# 将处理后的列表赋值回generated_idsgenerated_ids = processed_generated_ids# 解码模型输出,转换为可读文本response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]# 打印推理时间print(f'Inference time: {end-st:.2f} s')# 打印原始promptprint('-'*20, 'Prompt', '-'*20)print(text)# 打印模型生成的输出print('-'*20, 'Output', '-'*20)print(response)

Gradio交互

Gradio是一个功能强大的Web交互页面,Gradio的特点是可以非常简单的使用几行代码实现前端的页面,在这里我只是简单的使用了比赛baseline提供的一个简单的Gradio,后续有时间我也会专门补一篇gradio使用教程。

import os
import torch
import time
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM
from peft import PeftModel
import gradio as gr
from threading import Event# 设置OpenMP线程数为8, 优化CPU并行计算性能
os.environ["OMP_NUM_THREADS"] = "8"# 加载基础模型和分词器
base_model_name = "Qwen2-1-5B-Instruct"  # 替换为你的基础模型名称
model = AutoModelForCausalLM.from_pretrained(base_model_name,torch_dtype="auto",device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)# 加载LoRA微调后的权重
lora_checkpoint = "./checkpoint-781"
lora_model = PeftModel.from_pretrained(model, lora_checkpoint)# 创建一个停止事件,用于控制生成过程的中断
stop_event = Event()# 定义用户输入处理函数
def user(user_message, history):return "", history + [[user_message, None]]# 定义机器人回复生成函数
def bot(history):stop_event.clear()prompt = history[-1][0]messages = [{"role": "user", "content": prompt}]text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)model_inputs = tokenizer([text], return_tensors="pt").to('cpu')print(f"\n用户输入: {prompt}")print("模型输出: ", end="", flush=True)start_time = time.time()with torch.inference_mode():generated_ids = lora_model.generate(model_inputs.input_ids, max_new_tokens=512)processed_generated_ids = []for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids):input_length = len(input_ids)new_tokens = output_ids[input_length:]processed_generated_ids.append(new_tokens)generated_ids = processed_generated_idsresponse = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]history[-1][1] = responseend_time = time.time()print(f"\n生成完成,用时: {end_time - start_time:.2f} 秒")return historydef stop_generation():stop_event.set()with gr.Blocks() as demo:gr.Markdown("# Qwen 聊天机器人")chatbot = gr.Chatbot()msg = gr.Textbox()clear = gr.Button("清除")stop = gr.Button("停止生成")msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)clear.click(lambda: None, None, chatbot, queue=False)stop.click(stop_generation, queue=False)if __name__ == "__main__":print("启动 Gradio 界面...")demo.queue()demo.launch(root_path='/dsw-607012/proxy/7860/')

运行代码即可启动本次项目的界面,测试界面如下:
请添加图片描述
在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Xlua原理分析 四
  • 虚拟机ubuntu22.04找不到ttyUSB*端口
  • Windows系统之环境变量
  • Lumos学习王佩丰Excel第十讲:Sumif函数
  • .NET未来路在何方?
  • ei会议论文是什么级别
  • 登录相关功能的优化【JWT令牌+拦截器+跨域】
  • 研0 冲刺算法竞赛 day27 P1090 [NOIP2004 提高组] 合并果子 / [USACO06NOV] Fence Repair G
  • linux 进程 inode 信息获取
  • Java 面试常见问题之——final,finalize 和 finally 的不同之处
  • Java IO与NIO的对比与高级用法
  • python-打分(赛氪OJ)
  • 书生大模型实战营第三期——入门岛——Git基础知识
  • 【Android】四大组件(Activity、Service、Broadcast Receiver、Content Provider)、结构目录
  • DataX迁移数据到StarRocks超大表报too many version问题记录
  • [ JavaScript ] 数据结构与算法 —— 链表
  • [数据结构]链表的实现在PHP中
  • 【5+】跨webview多页面 触发事件(二)
  • 07.Android之多媒体问题
  • CSS实用技巧干货
  • el-input获取焦点 input输入框为空时高亮 el-input值非法时
  • golang 发送GET和POST示例
  • Java,console输出实时的转向GUI textbox
  • JAVA并发编程--1.基础概念
  • Java程序员幽默爆笑锦集
  • java正则表式的使用
  • Netty 框架总结「ChannelHandler 及 EventLoop」
  • React组件设计模式(一)
  • REST架构的思考
  • UEditor初始化失败(实例已存在,但视图未渲染出来,单页化)
  • 关于Android中设置闹钟的相对比较完善的解决方案
  • 蓝海存储开关机注意事项总结
  • 原创:新手布局福音!微信小程序使用flex的一些基础样式属性(一)
  • ​​​​​​​​​​​​​​Γ函数
  • # MySQL server 层和存储引擎层是怎么交互数据的?
  • #、%和$符号在OGNL表达式中经常出现
  • #HarmonyOS:基础语法
  • #QT(智能家居界面-界面切换)
  • #QT项目实战(天气预报)
  • #鸿蒙生态创新中心#揭幕仪式在深圳湾科技生态园举行
  • $Django python中使用redis, django中使用(封装了),redis开启事务(管道)
  • (2)关于RabbitMq 的 Topic Exchange 主题交换机
  • (c语言版)滑动窗口 给定一个字符串,只包含字母和数字,按要求找出字符串中的最长(连续)子串的长度
  • (附源码)小程序儿童艺术培训机构教育管理小程序 毕业设计 201740
  • (三)centos7案例实战—vmware虚拟机硬盘挂载与卸载
  • (四) Graphivz 颜色选择
  • (一)pytest自动化测试框架之生成测试报告(mac系统)
  • (原創) 博客園正式支援VHDL語法著色功能 (SOC) (VHDL)
  • (转)scrum常见工具列表
  • (转)如何上传第三方jar包至Maven私服让maven项目可以使用第三方jar包
  • (转载)hibernate缓存
  • .dat文件写入byte类型数组_用Python从Abaqus导出txt、dat数据
  • .h头文件 .lib动态链接库文件 .dll 动态链接库
  • .java 指数平滑_转载:二次指数平滑法求预测值的Java代码
  • .L0CK3D来袭:如何保护您的数据免受致命攻击