gpt-2语言模型训练
一、通过下载对应的语言模型数据集
1.1 根据你想让回答的内容,针对性下载对应的数据集,我下载的是个医疗问答数据集
1.2 针对你要用到的字段信息进行处理,然后把需要处理的数据丢给模型去训练,这个模型我是直接从GPT2的网站下载下来的依赖的必要文件截图如下:
二、具体代码样例实现:
import os
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, TextDataset, \DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForCausalLM# 读取CSV文件
data_path = '内科500.csv' # 替换为你的CSV文件路径
df = pd.read_csv(data_path, encoding='ISO-8859-1')# 将数据集转换为适合训练的格式
def preprocess_dialogues(df):conversations = []for index, row in df.iterrows():department = row['department']title = row['title']ask = row['ask']answer = row['answer']# 将每条问答对转换为连续的对话context = f"科室: {department}\n问题: {title}\n提问: {ask}\n回答: {answer}\n"conversations.append(context)return conversationsconversations = preprocess_dialogues(df)# 保存对话数据到文本文件
train_file_path = 'train_data.txt'
with open(train_file_path, 'w', encoding='utf-8') as file:for conversation in conversations:file.write(conversation + '\n')# 加载预训练模型和tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-model')
model = GPT2LMHeadModel.from_pretrained('./gpt2-model')# 准备数据集
def load_dataset(file_path, tokenizer, block_size=128):return TextDataset(tokenizer=tokenizer,file_path=file_path,block_size=block_size)train_dataset = load_dataset(train_file_path, tokenizer)# 数据整理器
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,mlm=False
)# 训练参数
training_args = TrainingArguments(output_dir='./results',overwrite_output_dir=True,num_train_epochs=3,per_device_train_batch_size=4,save_steps=10_000,save_total_limit=2,resume_from_checkpoint=True # 从检查点恢复训练
)# 创建Trainer
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset
)last_checkpoint = None
if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):last_checkpoint = training_args.output_dir
# 开始训练
trainer.train(resume_from_checkpoint=last_checkpoint)