当前位置: 首页 > news >正文

gpt-2语言模型训练

一、通过下载对应的语言模型数据集 

1.1 根据你想让回答的内容,针对性下载对应的数据集,我下载的是个医疗问答数据集

1.2 针对你要用到的字段信息进行处理,然后把需要处理的数据丢给模型去训练,这个模型我是直接从GPT2的网站下载下来的依赖的必要文件截图如下:

二、具体代码样例实现:

import os
import pandas as pd
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments, TextDataset, \DataCollatorForLanguageModeling
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForCausalLM# 读取CSV文件
data_path = '内科500.csv'  # 替换为你的CSV文件路径
df = pd.read_csv(data_path, encoding='ISO-8859-1')# 将数据集转换为适合训练的格式
def preprocess_dialogues(df):conversations = []for index, row in df.iterrows():department = row['department']title = row['title']ask = row['ask']answer = row['answer']# 将每条问答对转换为连续的对话context = f"科室: {department}\n问题: {title}\n提问: {ask}\n回答: {answer}\n"conversations.append(context)return conversationsconversations = preprocess_dialogues(df)# 保存对话数据到文本文件
train_file_path = 'train_data.txt'
with open(train_file_path, 'w', encoding='utf-8') as file:for conversation in conversations:file.write(conversation + '\n')# 加载预训练模型和tokenizer
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-model')
model = GPT2LMHeadModel.from_pretrained('./gpt2-model')# 准备数据集
def load_dataset(file_path, tokenizer, block_size=128):return TextDataset(tokenizer=tokenizer,file_path=file_path,block_size=block_size)train_dataset = load_dataset(train_file_path, tokenizer)# 数据整理器
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer,mlm=False
)# 训练参数
training_args = TrainingArguments(output_dir='./results',overwrite_output_dir=True,num_train_epochs=3,per_device_train_batch_size=4,save_steps=10_000,save_total_limit=2,resume_from_checkpoint=True  # 从检查点恢复训练
)# 创建Trainer
trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset
)last_checkpoint = None
if os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir):last_checkpoint = training_args.output_dir
# 开始训练
trainer.train(resume_from_checkpoint=last_checkpoint)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 揭秘!移动安全管理系统是什么?有什么功能?(从小白到精通一文揭晓!)
  • java在实际开发中反常识bug
  • 阿里一款非常不错的多级缓存框架如何使用?
  • Nmap扫描六种端口状态介绍
  • 【java计算机毕设】足浴城消费系统小程序MySQL ssm vue uniapp maven项目设计源代码 编程语言 小组课后作业 寒暑假作业
  • 使用docker compose一键部署 Portainer
  • XSS-过滤特殊符号的正则绕过
  • 从易车“超级818冠军之夜” 看如何借势体育营销点燃汽车消费热潮
  • 框架——Mybatis(!!!MyBatis 环境搭建步骤)
  • Redis远程字典服务器(7)—— set类型详解
  • VAuditDemo常规漏洞
  • DBAPI如何用SQL将多表关联查询出树状结构数据(嵌套JSON格式)
  • 论文解读:LONGWRITER: UNLEASHING 10,000+ WORD GENERATION FROM LONG CONTEXT LLMS
  • 精准掌控,速看顶级软件资产管理方案,让您企业软件资产一目了然!
  • ArcGIS Pro基础:状态栏显示栏的比例尺设置和经纬度位置
  • [js高手之路]搞清楚面向对象,必须要理解对象在创建过程中的内存表示
  • CAP理论的例子讲解
  • Django 博客开发教程 16 - 统计文章阅读量
  • Java新版本的开发已正式进入轨道,版本号18.3
  • JAVA之继承和多态
  • leetcode-27. Remove Element
  • leetcode讲解--894. All Possible Full Binary Trees
  • nfs客户端进程变D,延伸linux的lock
  • pdf文件如何在线转换为jpg图片
  • Vue UI框架库开发介绍
  • Vue2.0 实现互斥
  • Vue实战(四)登录/注册页的实现
  • 从零搭建Koa2 Server
  • 入职第二天:使用koa搭建node server是种怎样的体验
  • 使用 @font-face
  • 微信小程序:实现悬浮返回和分享按钮
  • 用jquery写贪吃蛇
  • AI又要和人类“对打”,Deepmind宣布《星战Ⅱ》即将开始 ...
  • ​一、什么是射频识别?二、射频识别系统组成及工作原理三、射频识别系统分类四、RFID与物联网​
  • #DBA杂记1
  • #pragma once
  • ( 用例图)定义了系统的功能需求,它是从系统的外部看系统功能,并不描述系统内部对功能的具体实现
  • (4) PIVOT 和 UPIVOT 的使用
  • (Python) SOAP Web Service (HTTP POST)
  • (Redis使用系列) SpringBoot中Redis的RedisConfig 二
  • (附源码)spring boot基于Java的电影院售票与管理系统毕业设计 011449
  • (牛客腾讯思维编程题)编码编码分组打印下标(java 版本+ C版本)
  • (实战)静默dbca安装创建数据库 --参数说明+举例
  • (四)linux文件内容查看
  • (一)Thymeleaf用法——Thymeleaf简介
  • (转)C#开发微信门户及应用(1)--开始使用微信接口
  • (轉貼) 2008 Altera 亞洲創新大賽 台灣學生成果傲視全球 [照片花絮] (SOC) (News)
  • (自用)网络编程
  • .net 7 上传文件踩坑
  • .NET Core 2.1路线图
  • .NET Core跨平台微服务学习资源
  • .NET 使用配置文件
  • .Net8 Blazor 尝鲜
  • .net打印*三角形
  • .NET构架之我见