当前位置: 首页 > news >正文

【2024】Datawhale AI夏令营 Task4笔记——vllm加速方式修改及llm推理参数调整上分

【2024】Datawhale AI夏令营 Task4笔记——vllm加速方式修改及llm推理参数调整上分

本文承接文章【2024】Datawhale AI夏令营 Task3笔记——Baseline2部分代码解读及初步上分思路,对其中vllm加速方式进行修改,推理速度获得了极大提升。另外,在延用多路投票的同时,通过调整大语言模型的参数获得了一些分数的提升。

🔴本文主要的注意点:

1、在使用vllm离线推理时,prompt信息需要装入messages并应用tokenizer的对话模板,否则回答会非常抽象。

2、llm推理参数调整对上分的帮助较小,大概在0.1左右。

一、vLLM加速方式修改

文章【2024】Datawhale AI夏令营 Task3笔记——Baseline2部分代码解读及初步上分思路中使用的vLLM加速方式是类openAI的API服务(vLLM启动的相关参数及解释可参考文章:VLLM参数解释-中文表格形式),本文使用的vLLM加速方式是离线批量推理

vLLM离线批量推理的参考文章:

Qwen-离线推理(仅实现离线推理,未实现批量)

使用vLLM和ChatGLM3-6b批量推理(实现离线批量推理,但不完全适用于本次比赛)

Using VLMs(官方文档,实现与图像相关的离线批量推理,但不完全适用于本次比赛)

本文最终使用的vLLM离线批量推理的代码如下。

1.1 引入相关包,创建LLM模型对象及tokenizer

from vllm import LLM, SamplingParams
from transformers import AutoModelForCausalLM, AutoTokenizer
import torchdevice = "cuda"
model_path = './merged_model_ana_my'
llm = LLM(model_path) # 使用vllm.LLM()创建LLM对象
tokenizer = AutoTokenizer.from_pretrained(model_path) # 使用AutoTokenizer.from_pretrained()创建tokenizer

🔴注意:

1、只需要提供模型路径即可创建LLM对象。不需要另外使用类似model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16).eval()的代码创建模型对象,这样可能会导致加载模型权重时程序被Killed或者推理时内存不足(因为创建的模型对象会占用较大的内存空间)。

2、tokenizer还可以通过如下方式创建:

device = "cuda"
model_path = './merged_model_ana_my'
llm = LLM(model_path, model_path) # 第一个model_path表示使用该路径下的model,第二个model_path表示使用该路径下的tokenizer(不再使用AutoTokenizer.from_pretrained()创建tokenizer)

这种方式似乎更加简洁,但为何最终不使用这种方式?原因在后面会提到。

1.2 修改process_datas()函数,实现(多路)离线批量推理

def process_datas(datas, MODEL_NAME):prompts = []results = []# os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 设置使用第1块GPU# 获取每个问题的prompt,并将prompt信息装入messages,(关键)再应用tokenizer的对话模板for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):problem = data['problem']for id, question in enumerate(data['questions']):prompt = get_prompt(problem, question['question'], question['options'],)messages = [{"role": "user", "content": prompt}]text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)prompts.append(text) # 将处理完成的prompt添加入prompts列表,准备输入vllm批量推理# 定义推理参数sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=512)# 开始推理# 单路投票推理# outputs = llm.generate(prompts, sampling_params)# 多路投票推理(这里通过进行三次推理,模仿多路投票的过程)outputs1 = llm.generate(prompts, sampling_params)outputs2 = llm.generate(prompts, sampling_params)outputs3 = llm.generate(prompts, sampling_params)'''单路投票'''# i = 0# for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):#     for id, question in enumerate(data['questions']):#         generated_text = outputs[i].outputs[0].text#         i = i + 1#         extract_response= extract(generated_text)#         data['questions'][id]['answer'] = extract_response#         results.append(data)'''多路投票'''i = 0 # 由于outputs中存储的回答序号并不是与datas中的序号一一对应(因为一个问题背景下可能有多个问题),因此使用一个计数变量另外遍历outputsfor data in tqdm(datas, desc="Extracting answers", total=len(datas)):for id, question in enumerate(data['questions']):# 获取每一路推理的回答文本generated_text1 = outputs1[i].outputs[0].textgenerated_text2 = outputs2[i].outputs[0].textgenerated_text3 = outputs3[i].outputs[0].texti = i + 1# 从文本中提取答案选项extract_response1, extract_response2, extract_response3 = extract(generated_text1),  extract(generated_text2),  extract(generated_text3)# 投票选择出现次数最多的选项作为答案ans = most_frequent_char(extract_response1, extract_response2, extract_response3)data['questions'][id]['answer'] = ansresults.append(data)return results

这样修改后,在与前一篇文章同样的环境下,模型推理完成全部问题只需使用约3min30s,相较于原先的7h提升很多。造成这种差异的原因可能是原先每推理一个问题就需要启动一次vllm,启动耗时较大,因此整体速度慢。现在能够将所有问题的prompt一次性传入vllm进行离线批量推理,速度更快。

🔴注意:prompt的内容影响模型的性能。在进行推理时,如果传入的prompt没有经过messages包装、没有应用tokenizer的对话模板,推理出来的文本会非常抽象,例如对于如下问题:

{"problem": "有一群人和一些食物类型。下列是关于这些个体和食物的已知信息:\n\n1. 鸡肉是一种食物。\n2. 苹果是一种食物。\n3. 如果X吃了Y,且X活着,则Y是一种食物。\n4. Bill存活。\n5. Bill吃了花生。\n6. John吃所有食物。\n7. Sue吃所有Bill吃的食物。\n8. John喜欢所有食物。\n\n根据以上信息,回答以下选择题:", "questions": [{"question": "选择题 1:\n谁喜欢吃花生?", "options": ["Bill", "Sue", "John", "None of the above"]}], "id": "round1_test_data_000"}

它的回答是这样的(无中生有了更多选择题):
在这里插入图片描述

对其他问题,回答甚至可能是这样的:

在这里插入图片描述

可以说是非常抽象、已读乱回。

prompt经过messages包装、应用tokenizer的对话模板后就正常多了(但是这一步为什么这么关键,我也还不是很懂):
在这里插入图片描述

这也是为什么在前面要单独创建tokenizer,就是为了在后面能够对prompt应用tokenizer的对话模板。

二、llm推理参数调整上分

其实这只是一个比较低级的trick,还不涉及微调、数据集等技术(时间较短,还未来得及学习应用其他技术)。主要调整llm参数的地方就在process_datas函数中sampling_params定义的位置。

sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=512)

关于SamplingParams参数的解释可以查看文档Sampling Parameters,这里这设置了一部分推理参数:temperaturetop_prepetition_penalty

SamplingParams参数的解释可以查看文档Sampling Parameters,这里这设置了一部分推理参数:temperaturetop_prepetition_penalty

这部分是否真的能够提分还没有做对比实验(毕竟验证会消耗提交次数),但是与前一篇文章中的最高分相比,使用此篇文章的代码再次推理出答案后,得到的分数提升了0.1。而本文代码与前一篇文章的代码相比,与推理准确度有关的部分只做了这一方面的改动,vllm加速方式的改动应该不影响推理准确度,所以暂且认为这部分参数的调整有助于微小提分。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 神经网络基础--激活函数
  • 深拷贝——JSON.stringify()序列化和JSON.prase()反序列化
  • 【C语言】Top K问题【建小堆】
  • 浙大版《C语言程序设计(第3版)》题目集
  • JavaScript 继承百花齐放:从原型链到 ES6 类
  • 软设之TCP/IP协议
  • 软科中国大学排名爬虫+数据可视化
  • 图片管理组建
  • Flink 实时数仓(三)【DWD 层搭建(一)】
  • 《人性的枷锁:菲利普的人生探索能解开枷锁吗?》
  • 树套树模板
  • PYTHON专题-(5)类的专有方法
  • 每日学术速递8.3
  • Xilinx管脚验证流程及常见问题
  • conda环境pip 安装Tensorflow-gpu 2.10.2提示nbconvert 的包依赖冲突
  • express.js的介绍及使用
  • github指令
  • Iterator 和 for...of 循环
  • JAVA 学习IO流
  • Java到底能干嘛?
  • leetcode98. Validate Binary Search Tree
  • MySQL用户中的%到底包不包括localhost?
  • React系列之 Redux 架构模式
  • Spring Cloud Feign的两种使用姿势
  • tweak 支持第三方库
  • 阿里云前端周刊 - 第 26 期
  • 汉诺塔算法
  • 前端学习笔记之观察者模式
  • 算法之不定期更新(一)(2018-04-12)
  • 腾讯优测优分享 | Android碎片化问题小结——关于闪光灯的那些事儿
  • 栈实现走出迷宫(C++)
  • 大数据全解:定义、价值及挑战
  • ​LeetCode解法汇总2808. 使循环数组所有元素相等的最少秒数
  • ​创新驱动,边缘计算领袖:亚马逊云科技海外服务器服务再进化
  • ​学习笔记——动态路由——IS-IS中间系统到中间系统(报文/TLV)​
  • #[Composer学习笔记]Part1:安装composer并通过composer创建一个项目
  • (0)Nginx 功能特性
  • (1综述)从零开始的嵌入式图像图像处理(PI+QT+OpenCV)实战演练
  • (C++)栈的链式存储结构(出栈、入栈、判空、遍历、销毁)(数据结构与算法)
  • (Matalb时序预测)WOA-BP鲸鱼算法优化BP神经网络的多维时序回归预测
  • (二)Eureka服务搭建,服务注册,服务发现
  • (二)JAVA使用POI操作excel
  • (附源码)springboot美食分享系统 毕业设计 612231
  • (附源码)计算机毕业设计SSM疫情下的学生出入管理系统
  • (转)母版页和相对路径
  • (转载)微软数据挖掘算法:Microsoft 时序算法(5)
  • (转载)虚幻引擎3--【UnrealScript教程】章节一:20.location和rotation
  • *p++,*(p++),*++p,(*p)++区别?
  • .gitignore文件忽略的内容不生效问题解决
  • .mat 文件的加载与创建 矩阵变图像? ∈ Matlab 使用笔记
  • .NET “底层”异步编程模式——异步编程模型(Asynchronous Programming Model,APM)...
  • .NET Core/Framework 创建委托以大幅度提高反射调用的性能
  • .NET Project Open Day(2011.11.13)
  • .NET 跨平台图形库 SkiaSharp 基础应用
  • .Net 路由处理厉害了