当前位置: 首页 > news >正文

0基础学会在亚马逊云科技AWS上利用SageMaker、PEFT和LoRA高效微调AI大语言模型(含具体教程和代码)

项目简介:

小李哥今天将继续介绍亚马逊云科技AWS云计算平台上的前沿前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS上的AI软甲开发最佳实践,并应用到自己的日常工作里。本次介绍的是如何在Amazon SageMaker上微调(Fine-tune)大语言模型dolly-v2-3b,满足日常生活中不同的场景需求,并将介分享如何在SageMaker上优化模型性能并节省计算资源实现成本控制,最后将部署后的大语言模型URL集成到自己云上的软件应用中。

本方案包括通过Amazon Cloudfront和S3托管前端页面,并通过Amazon API Gateway和AWS Lambda将应用程序与AI模型集成,调用大模型实现推理。本方案的解决方案架构图如下:

利用微调模型创建的对话机器人前端UI

利用本方案小李哥用微调后的模型搭建了一个Q&A对话机器人助手,可以生成代码、文字总结、回答问题。

在开始分享案例之前,我们来了解一下本方案的技术背景,帮助大家更好的理解方案架构。

什么是Amazon SageMaker?

Amazon SageMaker 是一个完全托管的机器学习服务(大家可以理解为Serverless的Jupyter Notebook),专为应用开发和数据科学家设计,帮助他们快速构建、训练和部署机器学习模型。使用 SageMaker,您无需担心底层基础设施的管理,可以专注于模型的开发和优化。它提供了一整套工具和功能,包括数据准备、模型训练、超参数调优、模型部署和监控,简化了整个机器学习工作流程。

本方案将介绍以下内容:

1. 使用 SageMaker Jupyter Notebook进行dolly-v2-3b模型开发和微调

2. 在SageMaker部署微调后的大语言模型LLM并基于数据进行推理

3. 使用多场景的测试案例验证推理结果表现,并将部署的模型API节点集成进云端应用

项目搭建具体步骤:

下面跟着小李哥手把手微调一个亚马逊云科技AWS上的生成式AI模型(dolly-v2-3b)的软件应用,并将AI大模型部署与应用集成。

1. 在控制台进入Amazon SageMaker, 点击Notebook

2. 打开Jupyter Notebook

3. 创建一个新的Notebook:“lab-notebook.ipynb”并打开

4. 接下来我们在单元格内一步一步运行代码,检查CUDA的内存状态

!nvidia-smi

5.接下来,我们安装必要依赖并导入

%%capture!pip3 install -r requirements.txt --quiet
!pip install sagemaker --quiet --upgrade --force-reinstall
%%captureimport os
import numpy as np
import pandas as pd
from typing import Any, Dict, List, Tuple, Union
from datasets import Dataset, load_dataset, disable_caching
disable_caching() ## disable huggingface cachefrom transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import TextDatasetimport torch
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer
import accelerate
import bitsandbytesfrom IPython.display import Markdown

6. 导入提前准备好的FAQs数据集

sagemaker_faqs_dataset = load_dataset("csv", data_files='data/amazon_sagemaker_faqs.csv')['train']
sagemaker_faqs_dataset
sagemaker_faqs_dataset[0]

7. 我们定义用于模型推理的提示词格式

from utils.helpers import INTRO_BLURB, INSTRUCTION_KEY, RESPONSE_KEY, END_KEY, RESPONSE_KEY_NL, DEFAULT_SEED, PROMPT
'''
PROMPT = """{intro}{instruction_key}{instruction}{response_key}{response}{end_key}"""
'''
Markdown(PROMPT)

8. 下面我们进入重头戏,导入一个提前预训练好的LLM大语言模型“databricks/dolly-v2-3b”。

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b",# use_cache=False,device_map="auto", #"balanced",load_in_8bit=True,
)

9. 对模型训练进行预准备, 处理数据集、优化模型训练(PEFT)效率

model.resize_token_embeddings(len(tokenizer))from functools import partial
from utils.helpers import mlu_preprocess_batchMAX_LENGTH = 256
_preprocessing_function = partial(mlu_preprocess_batch, max_length=MAX_LENGTH, tokenizer=tokenizer)encoded_sagemaker_faqs_dataset = sagemaker_faqs_dataset.map(_preprocessing_function,batched=True,remove_columns=["instruction", "response", "text"],
)processed_dataset = encoded_sagemaker_faqs_dataset.filter(lambda rec: len(rec["input_ids"]) < MAX_LENGTH)split_dataset = processed_dataset.train_test_split(test_size=14, seed=0)
split_dataset

10. 同时我们使用LoRA(Low-Rank Adaptation)模型加速我们的模型微调

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskTypeMICRO_BATCH_SIZE = 8  
BATCH_SIZE = 64
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LORA_R = 256 # 512
LORA_ALPHA = 512 # 1024
LORA_DROPOUT = 0.05# Define LoRA Config
lora_config = LoraConfig(r=LORA_R,lora_alpha=LORA_ALPHA,lora_dropout=LORA_DROPOUT,bias="none",task_type="CAUSAL_LM"
)model = get_peft_model(model, lora_config)
model.print_trainable_parameters()from utils.helpers import MLUDataCollatorForCompletionOnlyLMdata_collator = MLUDataCollatorForCompletionOnlyLM(tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)

11. 接下来我们定义模型训练参数并开始训练。其中Batch=1,Step=20000,epoch为10.

EPOCHS = 10
LEARNING_RATE = 1e-4  
MODEL_SAVE_FOLDER_NAME = "dolly-3b-lora"training_args = TrainingArguments(output_dir=MODEL_SAVE_FOLDER_NAME,fp16=True,per_device_train_batch_size=1,per_device_eval_batch_size=1,learning_rate=LEARNING_RATE,num_train_epochs=EPOCHS,logging_strategy="steps",logging_steps=100,evaluation_strategy="steps",eval_steps=100, save_strategy="steps",save_steps=20000,save_total_limit=10,
)trainer = Trainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=split_dataset['train'],eval_dataset=split_dataset["test"],data_collator=data_collator,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

12. 接下来我们将微调后的模型保存在本地

trainer.model.save_pretrained(MODEL_SAVE_FOLDER_NAME)trainer.save_model()trainer.model.config.save_pretrained(MODEL_SAVE_FOLDER_NAME)tokenizer.save_pretrained(MODEL_SAVE_FOLDER_NAME)

13. 接下来,我们将保存到本地的模型进行部署,生成公开访问的API节点Endpoint

对部署所需要的参数进行定义和初始化

import boto3
import json
import sagemaker.djl_inference
from sagemaker.session import Session
from sagemaker import image_uris
from sagemaker import Modelsagemaker_session = Session()
print("sagemaker_session: ", sagemaker_session)aws_role = sagemaker_session.get_caller_identity_arn()
print("aws_role: ", aws_role)aws_region = boto3.Session().region_name
print("aws_region: ", aws_region)image_uri = image_uris.retrieve(framework="djl-deepspeed",version="0.22.1",region=sagemaker_session._region_name)
print("image_uri: ", image_uri)

进行模型部署

model_data="s3://{}/lora_model.tar.gz".format(mybucket)model = Model(image_uri=image_uri,model_data=model_data,predictor_cls=sagemaker.djl_inference.DJLPredictor,role=aws_role)

14.最后我们写入提示词,对大语言模型进行测试, 得到推理

outputs = predictor.predict({"inputs": "What solutions come pre-built with Amazon SageMaker JumpStart?"})from IPython.display import Markdown
Markdown(outputs)

15. 我们下面进入SageMaker Endpoint页面,得到刚部署的模型API端点的URL,通过这种方式我们就可以在应用中调用我们的微调后的大语言模型了。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 服务客户,保证质量:腾讯云产品的质量实践
  • 医疗健康信息的安全挑战与隐私保护最佳实践
  • 【周末闲谈】Stable Diffusion会魔法的绘画师
  • Facebook软体机器人与机器人框架:创新社交互动的未来
  • HarmonyOS 界面开发基础篇
  • 算法日常练习
  • JAVA简单封装UserUtil
  • LangChain与正则表达式:探索文本匹配的强大工具
  • 【数据结构】--- 堆的应用
  • 什么是SQL锁
  • ES6 Iterator 与 for...of 循环(五)
  • Spring中的适配器模式和策略模式
  • 罗马仕充电宝怎么样?罗马仕、西圣、绿联无线充电宝测评对比!
  • Spring boot 2.0 升级到 3.3.1 的相关问题 (一)
  • 力扣题解(不相交的线)
  • 30秒的PHP代码片段(1)数组 - Array
  • Computed property XXX was assigned to but it has no setter
  • Go 语言编译器的 //go: 详解
  • iOS编译提示和导航提示
  • JS笔记四:作用域、变量(函数)提升
  • oldjun 检测网站的经验
  • Selenium实战教程系列(二)---元素定位
  • -- 查询加强-- 使用如何where子句进行筛选,% _ like的使用
  • 第13期 DApp 榜单 :来,吃我这波安利
  • 仿天猫超市收藏抛物线动画工具库
  • 给第三方使用接口的 URL 签名实现
  • 解决iview多表头动态更改列元素发生的错误
  • 码农张的Bug人生 - 见面之礼
  • 前嗅ForeSpider教程:创建模板
  • 收藏好这篇,别再只说“数据劫持”了
  • 思否第一天
  • 问题之ssh中Host key verification failed的解决
  • 要让cordova项目适配iphoneX + ios11.4,总共要几步?三步
  • Salesforce和SAP Netweaver里数据库表的元数据设计
  • 阿里云服务器购买完整流程
  • 如何用纯 CSS 创作一个货车 loader
  • ​比特币大跌的 2 个原因
  • !! 2.对十份论文和报告中的关于OpenCV和Android NDK开发的总结
  • #数学建模# 线性规划问题的Matlab求解
  • #图像处理
  • ( 用例图)定义了系统的功能需求,它是从系统的外部看系统功能,并不描述系统内部对功能的具体实现
  • (175)FPGA门控时钟技术
  • (pojstep1.1.2)2654(直叙式模拟)
  • (顶刊)一个基于分类代理模型的超多目标优化算法
  • (附源码)php新闻发布平台 毕业设计 141646
  • (附源码)计算机毕业设计SSM智慧停车系统
  • (蓝桥杯每日一题)平方末尾及补充(常用的字符串函数功能)
  • (五)c52学习之旅-静态数码管
  • (原創) 如何刪除Windows Live Writer留在本機的文章? (Web) (Windows Live Writer)
  • (转)chrome浏览器收藏夹(书签)的导出与导入
  • (转)平衡树
  • (转载)PyTorch代码规范最佳实践和样式指南
  • ****** 二十三 ******、软设笔记【数据库】-数据操作-常用关系操作、关系运算
  • .NET中的十进制浮点类型,徐汇区网站设计
  • .NET中统一的存储过程调用方法(收藏)