当前位置: 首页 > news >正文

【基础模型】开始构建我们自己的大语言模型2:模型构建部分

在构建大语言模型的旅途中,模型构建是至关重要的一步。大语言模型通常基于深度学习技术,特别是循环神经网络(RNN)及其变体,如长短期记忆网络(LSTM)或门控循环单元(GRU),这些网络擅长处理序列数据。(我们在上一章中提到过,见上一章【拓展】部分)

1. 定义模型构建函数

首先,我们需要一个函数来定义并构建我们的模型。这个函数将接收多个参数,如词汇表大小(vocab_size)、嵌入维度(embedding_dim)、RNN单元数(rnn_units)、批量大小(batch_size)以及模型版本(mt)和窗口大小(window,目前他没用但以后我们优化模型的时候会用到)。

这是GRU的简单流程图:

输入层: 接收当前时间步的输入数据
重置门: 即sigmoid函数 决定新输入与旧记忆的结合
更新门: sigmoid函数并控制旧状态信息的保留程度
候选隐藏状态: 基于重置门输出和当前输入计算
最终隐藏状态: 结合更新门输出候选隐藏状态和旧隐藏状态

def build_model(vocab_size, embedding_dim, rnn_units, batch_size, mt=2.2, window=128):if mt == 1:# mt:模型版本,目前我们只有1版本model = tf.keras.Sequential([tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size, None]),tf.keras.layers.GRU(rnn_units,return_sequences=True,stateful=True,recurrent_initializer='glorot_uniform'),#GRU层,整个模型的核心部分。你也可以换成LSTM等tf.keras.layers.Dropout(0.2),  # 添加Dropout防止过拟合tf.keras.layers.Dense(vocab_size)  # 输出层,对应词汇表大小])return model
2. 损失函数

损失函数,简单来说,就是用来衡量我们模型预测结果和实际结果之间差异的一个“尺子”。在机器学习和深度学习中,我们的目标是让模型学会从输入数据中预测出正确的输出。但是,模型一开始并不知道怎么预测,它会通过不断地学习和调整自己的参数来尽可能地减少预测结果和实际结果之间的差异。
这个“差异”的大小,就是由损失函数来计算的。如果预测结果和实际结果完全一样,那么损失函数的值就是0(或者一个非常接近0的数),表示没有差异;如果预测结果和实际结果有差异,那么损失函数的值就会大于0,差异越大,损失函数的值就越大。


损失函数就像是一个“教练”,它告诉模型:“你看,你这次预测错了这么多,你得继续努力,调整你的参数,减少这个差异。”模型就会根据损失函数的反馈,不断地调整自己的参数,直到损失函数的值降到很低,也就是预测结果和实际结果之间的差异变得很小。


不同的任务会使用不同的损失函数。比如,在分类任务中,我们常用的损失函数有交叉熵损失函数;在回归任务中,我们常用的损失函数有均方误差损失函数。这些损失函数都是根据任务的特点来设计的,能够很好地衡量模型预测结果和实际结果之间的差异。

对于语言模型,我们通常使用交叉熵损失函数来衡量预测结果与实际标签之间的差异。在TensorFlow中,sparse_categorical_crossentropy是处理此类问题的理想选择,特别是当标签是整数时。

def loss(labels, logits):return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)
3. 加载模型权重

在训练过程中,我们可能希望从之前的检查点加载模型权重,以继续训练或进行预测。这可以通过model.load_weights()方法实现。

# fwidx与loadmodel在上一章的函数定义传参部分已经定义了
if fwidx > 0 or loadmodel:  model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE, mt=mt, window=window)model.load_weights(checkpoint_path)
else:model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE, mt=mt, window=window)
4. 编译模型

在模型训练之前,我们需要使用优化器和损失函数来编译模型。这里,我们使用Adam优化器,它自动调整学习率,非常适合大多数深度学习场景。

model.compile(optimizer=tf.keras.optimizers.Adam(lr=LR), loss=loss)
5. 模型概览

使用model.summary()方法可以打印出模型的概览,包括每层的名称、输出形状和参数数量。

try:model.summary()
except ValueError:print('Model Can not summary, We will summary it after a batch of train.')

如果因为某些原因(如模型未接收到输入数据的维度信息)而无法打印概览,可以在训练开始后再尝试。

总结

下一章我们就要开始正式的训练模型了!

上一章与本章的所有完整代码:

                                            
def train(mt=3,big_file=False,#是否采用大文件加载策略#微调数据集path_to_file = r'en_novel.txt',ntype_='_en',#保存为微调模型名称#设置vocab版本vtype_='_lx',#type_#fen=50,#数据量分几份fwidx=0,#第几份BATCH_SIZE = 64,loadmodel=False,pass_=-1,ste=0,):'''多出的参数不必理会,后面会用到'''global LR,param_data,p_ntypep_ntype=ntype_if ntype_[0]!='_':ntype_='_'+ntype_type_=ntype_print('path_to_file',path_to_file)print('LR',LR)import os#dataset与vocab是配对的!if not os.path.exists(r'E:\小思框架\论文\ganskchat\vocab'+vtype_+'.txt'):raise Exception("can't reading vocab from "+r'E:\小思框架\论文\ganskchat\vocab'+vtype_+'.txt')else:with open('E:\\小思框架\\论文\\ganskchat\\vocab'+vtype_+'.txt','r',encoding='utf-8') as f:vocab=eval(f.read())UNK=0unkli=[]char2idx = {u:i for i, u in enumerate(vocab)}idx2char = np.array(vocab)print('{')for char,_ in zip(char2idx, range(20)):print('  {:4s}: {:3d},'.format(repr(char), char2idx[char]))print('  ...\n}')# 设定每个输入句子长度的最大值seq_length = dic[mt][2]def split_input_target(chunk):input_text = chunk[:-1]target_text = chunk[1:]return input_text, target_textimport tensorflow as tfimport pickle# 假设 BATCH_SIZE 和 BUFFER_SIZE 已经定义好if 1:# 设定缓冲区大小,以重新排列数据集BUFFER_SIZE = 50000# 词集的长度vocab_size = len(vocab)# 嵌入的维度embedding_dim = dic[mt][0]#int(1024*2*1)# RNN 的单元数量rnn_units = dic[mt][1]#int(1024*4*2)window = dic[mt][2]# 加载保存的数据集def load_dataset(path):dataset = tf.data.experimental.load(path)return datasetif os.path.exists(r'E:\小思框架\论文\ganskchat\dataset'+ntype_+'_'+str(fwidx)):#换了batch后要重新处理数据集print('loading dataset')# 加载已经打乱过的数据集dataset = load_dataset(r'E:\小思框架\论文\ganskchat\dataset'+ntype_+'_'+str(fwidx))else:if big_file:text=[]with open(path_to_file, 'r',encoding='utf-8') as f:idxlen=0print('getting length\n')for _ in tqdm.tqdm(f):idxlen+=1st=idxlen//fen*fwidxed=idxlen//fen*(fwidx+1)with open(path_to_file, 'r',encoding='utf-8') as f:idx=0print('\n\nrunning data\n')for _ in tqdm.tqdm(f):if idx<st:continueif idx>=ed:breaktext.append(_)idx+=1text=''.join(text)else:text = open(path_to_file, 'r',encoding='utf-8').read()idxlen=len(text)st=idxlen//fen*fwidxed=idxlen//fen*(fwidx+1)text=text[st:ed]print('data size:',len(text))#text_as_int = np.array([char2idx[c] for c in text])text_as_int=[]cks=list(char2idx.keys())unk_li=set()len_=0for c in tqdm.tqdm(text):if c in cks:text_as_int.append(char2idx[c])else:c=cc.convert(c)if c in cks:#转为简体再次尝试text_as_int.append(char2idx[c])else:if not is_add:if not c in unk_li:with open('unk'+ntype_+'.txt','w',encoding='utf-8') as f:f.write(str(len_)+'\n'+str(list(unk_li)))#print('unknow:',repr(c))unk_li.add(c)len_+=1text_as_int.append(UNK)text_as_int=np.array(text_as_int)#————————————————————————————# 创建训练样本 / 目标char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)if 0:for i in char_dataset.take(5):print(i.numpy())print(idx2char[i.numpy()])sequences = char_dataset.batch(seq_length+1, drop_remainder=True)for item in sequences.take(5):print(repr(''.join(idx2char[(item.numpy())])))dataset = sequences.map(split_input_target)dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)tf.data.Dataset.save(dataset,'./dataset'+ntype_+'_'+str(fwidx) )# 以上是上一章代码
######################这是分界线#######################
# 以下是本章的代码def build_model(vocab_size, embedding_dim, rnn_units, batch_size,mt=2.2,window=128):#global mtif mt==1:model = tf.keras.Sequential([tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size, None]),tf.keras.layers.GRU(rnn_units,return_sequences=True,stateful=True,recurrent_initializer='glorot_uniform'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(vocab_size)])return modelelse:raise Exception('MT Error!')def loss(labels, logits):return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.0035), loss=loss)def preprocess_data_for_training(dd):input_sequences, target_sequences = ddreturn input_sequences, target_sequencesEPOCHS = 50000import osimport tensorflow as tfcheckpoint_dir = r'./dataset/ckpt'+type_if not os.path.exists(checkpoint_dir):os.makedirs(checkpoint_dir)steps_per_epoch = len(dataset) // BATCH_SIZEct()checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)ncheckpoint_dir = r'./dataset/ckpt'+ntype_if not os.path.exists(ncheckpoint_dir):os.makedirs(ncheckpoint_dir)ncheckpoint_path = tf.train.latest_checkpoint(ncheckpoint_dir)ct()if fwidx>0 or loadmodel:model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE,mt=mt,window=window)# 直接加载权重到模型中model.load_weights(checkpoint_path)else:#传入新的mtmodel = build_model(vocab_size, embedding_dim, rnn_units, batch_size=BATCH_SIZE,mt=mt,window=window)try:model.summary()ns=0except ValueError:print('Model Can not summary, We will summary it after a batch of train.')ns=1model.compile(optimizer=tf.keras.optimizers.Adam(lr=LR), loss=loss)

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • ElementUI,el-input输入框max、min限制最大最小值失效
  • cdga|数据资产运营:加速企业数据价值释放的新引擎
  • 【MySQL】访问mysqld的方式{命令行客户端/vscode-c-api客户端/图形化界面:mysql/navicat}
  • lucene中nvd和nvm索引文件作用以及规范化值是如何影响文档评分
  • uniapp map组件自定义markers标记点
  • 基于Windows Docker desktop搭建pwn环境
  • ATA-M8功率放大器在变压器老化中的作用是什么
  • python: 打包好的exe程序(冻结程序)中使用多进程,子进程不能正常执行!
  • 八股文-基础知识-int和Integer有什么区别?
  • 图片url处理(带http和不带http)方法
  • “微软蓝屏”事件:网络安全与稳定性的深刻反思
  • 深入学习H264和H265
  • 数据集相关类代码回顾理解 | StratifiedShuffleSplit\transforms.ToTensor\Counter
  • 【组件协作】模板方法
  • 【TS】TypeScript数组类型:掌握数据集合的类型安全
  • [分享]iOS开发 - 实现UITableView Plain SectionView和table不停留一起滑动
  • 《Java编程思想》读书笔记-对象导论
  • canvas 高仿 Apple Watch 表盘
  • css属性的继承、初识值、计算值、当前值、应用值
  • exports和module.exports
  • Git同步原始仓库到Fork仓库中
  • HTML中设置input等文本框为不可操作
  • iOS 系统授权开发
  • Spring-boot 启动时碰到的错误
  • 第三十一到第三十三天:我是精明的小卖家(一)
  • 飞驰在Mesos的涡轮引擎上
  • 互联网大裁员:Java程序员失工作,焉知不能进ali?
  • 简单实现一个textarea自适应高度
  • 理解在java “”i=i++;”所发生的事情
  • 前端面试题总结
  • 前端学习笔记之观察者模式
  • 学习使用ExpressJS 4.0中的新Router
  • 用Canvas画一棵二叉树
  • Nginx惊现漏洞 百万网站面临“拖库”风险
  • 浅谈sql中的in与not in,exists与not exists的区别
  • 小白应该如何快速入门阿里云服务器,新手使用ECS的方法 ...
  • ​软考-高级-信息系统项目管理师教程 第四版【第14章-项目沟通管理-思维导图】​
  • ‌JavaScript 数据类型转换
  • # Python csv、xlsx、json、二进制(MP3) 文件读写基本使用
  • $.extend({},旧的,新的);合并对象,后面的覆盖前面的
  • (2)关于RabbitMq 的 Topic Exchange 主题交换机
  • (k8s中)docker netty OOM问题记录
  • (libusb) usb口自动刷新
  • (Redis使用系列) Springboot 在redis中使用BloomFilter布隆过滤器机制 六
  • (SpringBoot)第七章:SpringBoot日志文件
  • (笔试题)合法字符串
  • (二)hibernate配置管理
  • (二)JAVA使用POI操作excel
  • (经验分享)作为一名普通本科计算机专业学生,我大学四年到底走了多少弯路
  • (算法)Travel Information Center
  • (贪心) LeetCode 45. 跳跃游戏 II
  • (一)插入排序
  • (已解决)什么是vue导航守卫
  • (转)ABI是什么
  • (转)原始图像数据和PDF中的图像数据