当前位置: 首页 > news >正文

GLM3源码学习

原文链接:chatglm源码学习

GLM3源码:https://github.com/THUDM/ChatGLM3

我们直接从openai_api_demo入手,因为api_demo一般是nlp模型后端核心功能实现的部分

openai_api_demo源码

api_server.py

api_server.py是提供web api接口的入口文件,是使用flask框架提供的一个异步接口支持

app = FastAPI(lifespan=lifespan)
class ModelCard(BaseModel):
...
class ChatCompletionResponse(BaseModel):

上面这一堆class是实现chat这个api功能的主要对象,如模型卡、请求体和响应体

@app.get("/health")
async def health() -> Response:"""Health check."""return Response(status_code=200)

这个是测试api状态函数,可以看到这个测试功能还是很直接的,没有考虑部署应用下的问题,如负载情况和安全状况,这个demo也就是一个学习的小demo项目。

@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def get_embeddings(request: EmbeddingRequest):if isinstance(request.input, str):# 判断输入是否是字符串,字符串直接编码,否则对字符串列表编码embeddings = [embedding_model.encode(request.input)]else:embeddings = [embedding_model.encode(text) for text in request.input]embeddings = [embedding.tolist() for embedding in embeddings]def num_tokens_from_string(string: str) -> int:"""Returns the number of tokens in a text string.use cl100k_base tokenizer"""encoding = tiktoken.get_encoding('cl100k_base')num_tokens = len(encoding.encode(string))return num_tokensresponse = {"data": [{"object": "embedding","embedding": embedding,"index": index}for index, embedding in enumerate(embeddings)],"model": request.model,"object": "list","usage": CompletionUsage(prompt_tokens=sum(len(text.split()) for text in request.input), completion_tokens=0,total_tokens=sum(num_tokens_from_string(text) for text in request.input),)}return response

这个函数是获取文本向量编码的,sentences_to_embeddings功能。
这里面有个函数num_tokens_from_string是统计文本的tokens数量,使用的tiktoken模块是openai开源的一个快速分词统计库,cl100k_base是和gpt4同款编码器,也就是说glm3的tokenizer实际上是使用的gpt4的tokenizer,在论文里面glm的baseline是最开始的gpt-1模型,那从理论上,glm3的性能提升肯定会受到分词的影响的(清华博士教大家的水论文小技巧hhh)。

class ModelCard(BaseModel):id: strobject: str = "model"created: int = Field(default_factory=lambda: int(time.time()))owned_by: str = "owner"root: Optional[str] = Noneparent: Optional[str] = Nonepermission: Optional[list] = None@app.get("/v1/models", response_model=ModelList)
async def list_models():model_card = ModelCard(id="chatglm3-6b")return ModelList(data=[model_card])

这个list_models直接限定了就是chatglm3-6b模型了,里面没有包括实际的模型

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):global model, tokenizerif len(request.messages) < 1 or request.messages[-1].role == "assistant":raise HTTPException(status_code=400, detail="Invalid request")#截这个的原因是gpt模型是允许任意角色的消息序列的包括assistant多次生成的功能。glm3则不允许if request.stream:#SSE流式响应response = generate_chatglm3(model, tokenizer, gen_params) #直接响应message = ChatMessage(role="assistant",content=response["text"],...)#创建消息体#计算使用量然后返回响应体,choice_data里面只放了一个数据return ChatCompletionResponse(model=request.model,id="",  # for open_source model, id is emptychoices=[choice_data],object="chat.completion",usage=usage)

chat最核心的响应函数了,由于函数较长就不全截了。

首先我们看到的是一个消息验证不允许assistant多次生成,原因主要是这个功能本身对助手是没有什么意义的,而且多次生成的训练效果比较差,之前我测试过gpt api的多次生成。因为他们用的训练数据基本上都是一个消息内全部回复了,上下文数据本身不存在多次生成的场景,因此这些模型多次生成并不是把问题分多次回复(和人类不同,一句话可以多方面讲,分段讲),只是把答案回答多次。

如果要实现更真实的问答AI,拥有更真实的对话体验,那对数据的要求是很高的,最好的数据集应该是QQ微信这种聊天软件的数据,但是企业是不可能拿这些隐私数据训练的。不过也有平替,如贴吧微博这些开放平台的数据也是很好的,但是这些数据看过后,上下文的逻辑性还是有问题的,并且多轮对话的人物被屏蔽了,也就是说明明是多个人的对话被训练成了二人的对话,这些模型后面肯定被高质量多轮对话微调过,不然单纯这些语料不会达到gpt的这种效果。

响应类型分为直接响应和SSE响应,其中直接响应简单,就是拿model直接推理得到message。
这里有个问题是这个chat函数是asyn异步的,但是model资源是global的单个模型,如果同时多个请求可能会报错。可以对模型封装个请求拥塞队列,比如大于3个请求就返回繁忙。

SSE响应部分和直接响应不同,SSE没有提供使用量这些信息,仅返回了响应文本,SSE还对前端的响应方法有要求,因此如果是仅学习开发和小规模应用没有必要追求SSE

    predict_stream_generator = predict_stream(request.model, gen_params)output = next(predict_stream_generator)if not contains_custom_function(output):return EventSourceResponse(predict_stream_generator, media_type="text/event-stream")

通过predict_stream创建一个生成器,next生成下一个字符然后返回。

utils.py

utils.py提供了响应的实现函数generate_stream_chatglm3和generate_chatglm3

def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):for response in generate_stream_chatglm3(model, tokenizer, params):passreturn response

循环调用generate_stream_chatglm3后返回响应

generate_stream_chatglm3函数

def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):messages = params["messages"] #消息tools = params["tools"] #工具temperature = float(params.get("temperature", 1.0)) #温度参数repetition_penalty = float(params.get("repetition_penalty", 1.0))#惩罚参数,transformer有个问题就是高概率文本会重复生成,在有的论文中提出了惩罚参数,即对已经生成的token的概率乘上惩罚参数让这个token的概率变小,减小重复概率。top_p = float(params.get("top_p", 1.0)) #top_p top_k是采样的一个过滤方法,p是按概率阈值过滤,k是按排序过滤max_new_tokens = int(params.get("max_tokens", 256)) #最大允许新生成的tokensecho = params.get("echo", True)messages = process_chatglm_messages(messages, tools=tools)#消息处理query, role = messages[-1]["content"], messages[-1]["role"]#最后一个消息内容inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role) #把历史和问题构建输入inputs = inputs.to(model.device)input_echo_len = len(inputs["input_ids"][0])#输入编码序列长度if input_echo_len >= model.config.seq_length: #输入序列长度限制print(f"Input length larger than {model.config.seq_length}")eos_token_id = [ #结束tokentokenizer.eos_token_id,tokenizer.get_command("<|user|>"),tokenizer.get_command("<|observation|>")]gen_kwargs = { #控制参数"max_new_tokens": max_new_tokens,"do_sample": True if temperature > 1e-5 else False,"top_p": top_p,"repetition_penalty": repetition_penalty,"logits_processor": [InvalidScoreLogitsProcessor()],}if temperature > 1e-5:gen_kwargs["temperature"] = temperaturetotal_len = 0for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):total_ids = total_ids.tolist()[0]total_len = len(total_ids)if echo: #没看懂echo什么意思 input_echo_len应该是生成的total_ids中echo控制是否对问题重复一遍,重复了就减掉output_ids = total_ids[:-1]  else:output_ids = total_ids[input_echo_len:-1]#反正output_ids是stream_generate的idsresponse = tokenizer.decode(output_ids)if response and response[-1] != "�": #乱码了就跳出response, stop_found = apply_stopping_strings(response, ["<|observation|>"]) #判断是否结束yield { #yield作为一个生成器每次调用生成output_ids然后返回"text": response,"usage": {"prompt_tokens": input_echo_len,#输入tokens"completion_tokens": total_len - input_echo_len,#总tokens-重复的输入tokens"total_tokens": total_len,#总tokens},"finish_reason": "function_call" if stop_found else None,}if stop_found:break#最后一个字符跳出返回结束# Only last stream result contains finish_reason, we set finish_reason as stopret = {"text": response,"usage": {"prompt_tokens": input_echo_len,"completion_tokens": total_len - input_echo_len,"total_tokens": total_len,},"finish_reason": "stop",}yield ret#内存显存收下垃圾gc.collect()torch.cuda.empty_cache() 

其中里面有个函数很关键process_chatglm_messages:消息处理函数

def process_chatglm_messages(messages, tools=None):_messages = messagesmessages = []msg_has_sys = Falseif tools:messages.append({"role": "system","content": "Answer the following questions as best as you can. You have access to the following tools:","tools": tools})msg_has_sys = Truefor m in _messages:role, content, func_call = m.role, m.content, m.function_callif role == "function":messages.append({"role": "observation","content": content})elif role == "assistant" and func_call is not None:for response in content.split("<|assistant|>"):metadata, sub_content = response.split("\n", maxsplit=1)messages.append({"role": role,"metadata": metadata,"content": sub_content.strip()})else:if role == "system" and msg_has_sys:msg_has_sys = Falsecontinuemessages.append({"role": role, "content": content})return messages

这个函数就是把message对象转化为dict对象,我们可以看到这里面有system、observation、assistant、user。

在generate_stream_chatglm3:inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
这个函数中把dict对象对应的history转换成了文本格式,例如:

<|system|>
You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.
<|user|>
Hello
<|assistant|>
Hello, I'm ChatGLM3. What can I assist you today?

也就是说我们以为的多轮对话实际上就是把历史记录拼起来的。

这部分想到了个idea,这种拼起来的实际上有历史限制,如果让模型生成每个对话的重要性,然后按照重要性+时间权重排序选择性记忆能不能增强长期记忆能力?感觉这部分应该有人在做或者做出来了。

main

最后回来看下api_server.py的main函数

if __name__ == "__main__":# Load LLMtokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()# load Embeddingembedding_model = SentenceTransformer(EMBEDDING_PATH, device="cuda")uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)

main函数从transformers加载模型然后作为global对象推理。
transformers模型就和传统的bert、t5类似了

chatglm的改进主要包括:

  • 二维位置编码+GELU+残差、层归一化重排序
  • 文档级+句子 NLG预训练
  • NLG+NLU两种任务都进行训练,同时微调的时候还使用了slot填空的NLU方法

总结

之前没怎么看过这种有上下文模型响应的完整流程,这趟下来解决了我之前好几个疑惑:

  1. transformer的重复问题我遇到了好几次,可以通过惩罚参数控制
  2. 上下文实现方法-实际上还是把历史对话融在一起
  3. 模型推理资源占用问题,请求队列感觉是一定要有的,web框架本身是异步请求响应的,不对临界资源管理感觉没啥可靠性

加上这个,目前已经把带上下文的文本生成+知识库扩展永久记忆解决了,后面再对模型结构魔改下,然后集成一些动作指令,就可以实现本地部署家用AI了hhh。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 《斯科特·凯尔比的风光摄影手册》读书笔记
  • 刷题之单词规律同构字符串(leetcode)
  • 2022-10-26 Qt6.5版本后视频渲染
  • Go 初始化一个字典value是列表
  • 前端/python脚本/转换-使用天地图下载的geojson(echarts4+如果直接使用会导致坐标和其他信息不全)
  • MongoDB - 查询操作符:比较查询、逻辑查询、元素查询、数组查询
  • 安全防御----防火墙综合实验2
  • 图论---匈牙利算法求二分图最大匹配的实现
  • pdf只要前几页,pdf中只要前几页怎么处理
  • pytorch-pytorch之LSTM
  • DockerCompose介绍,安装,使用
  • 【错题集-编程题】四个选项(DFS + 剪枝 + 哈希表)
  • 利用 AI 解放双手:把“贾维斯”带进现实 | 开源专题 No.64
  • 拥抱UniHttp,规范Http接口对接之旅
  • 基于JavaSpringBoot+Vue+uniapp微信小程序校园宿舍管理系统设计与实现
  • 78. Subsets
  • eclipse的离线汉化
  • fetch 从初识到应用
  • Javascripit类型转换比较那点事儿,双等号(==)
  • JS变量作用域
  • Material Design
  • mysql中InnoDB引擎中页的概念
  • PAT A1017 优先队列
  • React Transition Group -- Transition 组件
  • Redis在Web项目中的应用与实践
  • Sass 快速入门教程
  • 初识MongoDB分片
  • 从地狱到天堂,Node 回调向 async/await 转变
  • 动手做个聊天室,前端工程师百无聊赖的人生
  • 机器学习中为什么要做归一化normalization
  • 我的zsh配置, 2019最新方案
  • 以太坊客户端Geth命令参数详解
  • Nginx惊现漏洞 百万网站面临“拖库”风险
  • ​ ​Redis(五)主从复制:主从模式介绍、配置、拓扑(一主一从结构、一主多从结构、树形主从结构)、原理(复制过程、​​​​​​​数据同步psync)、总结
  • ​【原创】基于SSM的酒店预约管理系统(酒店管理系统毕业设计)
  • ​LeetCode解法汇总1410. HTML 实体解析器
  • ​RecSys 2022 | 面向人岗匹配的双向选择偏好建模
  • # wps必须要登录激活才能使用吗?
  • #Linux(帮助手册)
  • #微信小程序(布局、渲染层基础知识)
  • (5)STL算法之复制
  • (CVPRW,2024)可学习的提示:遥感领域小样本语义分割
  • (c语言版)滑动窗口 给定一个字符串,只包含字母和数字,按要求找出字符串中的最长(连续)子串的长度
  • (附源码)c#+winform实现远程开机(广域网可用)
  • (附源码)php投票系统 毕业设计 121500
  • (附源码)计算机毕业设计SSM疫情社区管理系统
  • (附源码)小程序 交通违法举报系统 毕业设计 242045
  • (九)One-Wire总线-DS18B20
  • (六)Hibernate的二级缓存
  • (免费领源码)Python#MySQL图书馆管理系统071718-计算机毕业设计项目选题推荐
  • (十八)三元表达式和列表解析
  • (转)视频码率,帧率和分辨率的联系与区别
  • .gitignore文件忽略的内容不生效问题解决
  • .gitignore文件设置了忽略但不生效
  • .net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法