欺诈文本分类微调(一):基座模型选型
背景
随着网络诈骗越来越多,经常有一些不法分子利用网络会议软件进行诈骗,为此需要训练一个文本分类检测模型,来实时检测会议中的对话内容是否存在诈骗风险,以帮助用户在网络会议中增强警惕,减少受到欺诈的风险。
考虑模型的预训练成本高昂,选择一个通用能力强的模型底座进行微调是一个比较经济的做法,我们选择模型底座主要考虑以下几个因素:
- 预训练主要采用中文语料,具有良好的中文支持能力
- 模型需要具备基本的指令遵循能力。
- 模型要能理解json格式,具备输出json格式的能力。
- 在满足以上几个能力的基础上,模型参数量越小越好。
通义千问Qwen2具有0.5B、1.5B、7B、72B等一系列参数大小不等的模型,我们需要做的是从大到小依次测试每个模型的能力,找到满足自己需要的最小参数模型。
模型下载
依次下载qwen2的不同尺寸的模型:0.5B-Instruct、1.5B、1.5B-Instruct、7B-Instruct,下载完后输出模型在本地磁盘的路径。
#模型下载
from modelscope import snapshot_download
cache_dir = '/data2/anti_fraud/models/modelscope/hub'
model_dir = snapshot_download('Qwen/Qwen2-7B-Instruct', cache_dir=cache_dir, revision='master')
# model_dir = snapshot_download('Qwen/Qwen2-0.5B-Instruct', cache_dir=cache_dir, revision='master')
# model_dir = snapshot_download('Qwen/Qwen2-1.5B-Instruct', cache_dir=cache_dir, revision='master')
# model_dir = snapshot_download('Qwen/Qwen2-1.5B', cache_dir=cache_dir, revision='master')
model_dir
'/data2/anti_fraud/models/modelscope/hub/Qwen/Qwen2-7B-Instruct'
封装工具函数
先封装一个函数load_model,用于从参数model_dir指定的路径中加载模型model和序列化器tokenizer,加载完后将模型手动移动到指定的GPU设备上。
我们这里采用的transformers库来加载模型,如果使用modelscope加载只是将
from tranformers
换成from modelscope
。本文测试的目的是用于模型底座选型,所以选择最原始的加载方式。
import os
from transformers import AutoModelForCausalLM, AutoTokenizerdef load_model(model_dir, device='cuda'):model = AutoModelForCausalLM.from_pretrained(model_dir,torch_dtype="auto",trust_remote_code=True# device_map="auto" )model = model.to(device)tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)return model, tokenizer
注:如果同时使用to(device)和device_map=“auto",在多GPU机器上可能会导致
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:2!
,经测试,已经加载到多GPU上的模型并不支持再通过to(device)移动到单GPU上。
再封装一个predict函数用于文本推理,考虑到我们将要用多个不同参数的模型分别进行测试,这里将model和tokenizer提取到参数中,以便复用这个方法。
# Instead of using model.chat(), we directly use model.generate()
# But you need to use tokenizer.apply_chat_template() to format your inputs as shown below
def predict(model, tokenizer, prompt, device='cuda', debug=True):messages = [{