当前位置: 首页 > news >正文

【MindSpore学习打卡】应用实践-LLM原理和实践-基于MindSpore实现BERT对话情绪识别

在当今的自然语言处理(NLP)领域,情绪识别是一个非常重要的应用场景。无论是在智能客服、社交媒体分析,还是在情感计算领域,准确地识别用户的情绪都能够极大地提升用户体验和系统的智能化水平。BERT(Bidirectional Encoder Representations from Transformers)作为一种强大的预训练语言模型,已经在多个NLP任务中展示了其卓越的性能。在这篇博客中,我们将详细介绍如何基于MindSpore框架,利用BERT模型实现对话情绪识别。通过一步步的代码示例和详细解释,帮助你掌握这一技术。

模型简介

BERT(Bidirectional Encoder Representations from Transformers)是一种基于Transformer的双向编码器表征模型。它主要通过两种预训练任务来捕捉词语和句子级别的表征:Masked Language Model(MLM)和Next Sentence Prediction(NSP)。

  • Masked Language Model:随机将语料库中15%的单词进行掩码操作,模型需要预测这些被掩码的单词。
  • Next Sentence Prediction:模型需要预测两个句子之间是否存在顺序关系。

BERT预训练后,可以用于多种下游任务,如文本分类、相似度判断、阅读理解等。

数据集准备

在数据集准备部分,我们下载并解压了百度飞桨团队提供的机器人聊天数据集。这个数据集已经过预处理,并包含了情绪标签。每一行数据由一个标签和一个经过分词处理的文本组成。标签表示情绪类别(0表示消极,1表示中性,2表示积极),文本则是用户的对话内容。通过使用这种结构化的数据,我们可以更方便地进行情感分类任务。

# 下载数据集
!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
!tar xvf emotion_detection.tar.gz

数据集格式如下:

label--text_a
0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?
1--我有事等会儿就回来和你聊
2--我见到你很高兴谢谢你帮我

数据加载和预处理

数据加载和预处理是机器学习流程中至关重要的一步。我们使用了GeneratorDataset来加载数据,并通过映射操作将文本转换为模型可以接受的格式。具体来说,我们使用了BertTokenizer将文本Tokenize成词汇ID,并进行填充(Pad)操作。这样做的目的是确保所有输入序列的长度一致,从而提高训练效率和模型性能。

import numpy as np
from mindspore.dataset import text, GeneratorDataset, transforms
from mindnlp.transformers import BertTokenizerdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)type_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')if is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id), 'attention_mask': (None, 0)})return datasettokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)

在这里插入图片描述

模型构建

在模型构建部分,我们使用了BertForSequenceClassification来进行情感分类任务。这个预训练模型已经在大规模语料上进行了训练,具有强大的语言理解能力。通过加载预训练权重,我们可以显著提升模型在情感分类任务上的表现。同时,我们使用了自动混合精度(auto mixed precision)技术,这不仅可以加速训练过程,还能减少显存使用,从而在有限的硬件资源下实现更高效的训练。

优化器和评价指标是模型训练中的重要组件。我们选择了Adam优化器,因为它在处理大规模数据和复杂模型时表现优异。评价指标方面,我们使用了准确率(Accuracy)来衡量模型的性能。通过这些设置,我们可以确保模型在训练过程中不断优化,并在验证集上取得良好的表现。

回调函数在模型训练过程中发挥着重要作用。我们设置了两个回调函数:CheckpointCallbackBestModelCallback。前者用于定期保存模型的权重,后者则自动加载表现最好的模型权重。通过这些回调函数,我们可以确保在训练过程中不会丢失重要的模型参数,并且始终使用表现最佳的模型进行推理和评估。

from mindnlp.transformers import BertForSequenceClassification
from mindspore import nn
from mindnlp._legacy.amp import auto_mixed_precisionmodel = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])
trainer.run(tgt_columns="labels")

模型验证

在模型验证部分,我们使用验证数据集来评估模型的性能。通过计算模型在验证集上的准确率,我们可以了解模型的泛化能力和实际效果。这一步骤非常重要,因为它可以帮助我们发现模型在训练过程中可能存在的问题,并进行相应的调整和优化。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

模型推理

模型推理部分展示了如何使用训练好的模型对新数据进行情感分类。我们定义了一个predict函数,通过输入文本进行情感预测,并输出预测结果。这个步骤展示了模型的实际应用能力,并验证了模型的泛化性能。

dataset_infer = SentimentDataset("data/infer.tsv")def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)for label, text in dataset_infer:predict(text, label)

在这里插入图片描述

自定义推理数据

最后,我们展示了如何使用模型对自定义输入进行情感识别。这一步骤不仅展示了模型的实际应用能力,还验证了模型在不同输入下的表现。通过这种方式,我们可以进一步了解模型的泛化能力和实际效果。

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

在这里插入图片描述

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 【机器学习】独立成分分析(ICA):解锁信号的隐秘面纱
  • 道路运输企业管理人员安全考核试题(附答案)
  • 如何在工作中开悟?
  • element 如何实现文件上传下载导出
  • 基于QT开发的反射内存小工具
  • OWASP ZAP
  • 低代码研发项目管理流程优化:提效与创新的双重驱动
  • 【Unity2D 2022:UI】制作主菜单
  • 昇思25天学习打卡营第1天|初步了解
  • [Linux][Shell][Shell函数]详细讲解
  • Qt 统计图编程
  • Apache Seata分布式事务启用Nacos做配置中心
  • 禅道二次开发——禅道zentaoPHP框架扩展机制——对视图层(view)扩展
  • Linux账号和权限管理详解
  • 浅谈MMORPG的战斗系统
  • ES6指北【2】—— 箭头函数
  • 【css3】浏览器内核及其兼容性
  • 【翻译】babel对TC39装饰器草案的实现
  • CentOS 7 修改主机名
  • gf框架之分页模块(五) - 自定义分页
  • go语言学习初探(一)
  • If…else
  • java中的hashCode
  • Making An Indicator With Pure CSS
  • NSTimer学习笔记
  • SpriteKit 技巧之添加背景图片
  • UMLCHINA 首席专家潘加宇鼎力推荐
  • 基于 Ueditor 的现代化编辑器 Neditor 1.5.4 发布
  • 前端性能优化——回流与重绘
  • 如何编写一个可升级的智能合约
  • 使用Swoole加速Laravel(正式环境中)
  • 腾讯优测优分享 | 你是否体验过Android手机插入耳机后仍外放的尴尬?
  • 栈实现走出迷宫(C++)
  • 浅谈sql中的in与not in,exists与not exists的区别
  • ​Base64转换成图片,android studio build乱码,找不到okio.ByteString接腾讯人脸识别
  • ​学习一下,什么是预包装食品?​
  • # 深度解析 Socket 与 WebSocket:原理、区别与应用
  • #LLM入门|Prompt#1.8_聊天机器人_Chatbot
  • (16)Reactor的测试——响应式Spring的道法术器
  • (3)选择元素——(17)练习(Exercises)
  • (C语言)strcpy与strcpy详解,与模拟实现
  • (C语言)二分查找 超详细
  • (delphi11最新学习资料) Object Pascal 学习笔记---第5章第5节(delphi中的指针)
  • (HAL库版)freeRTOS移植STMF103
  • (Redis使用系列) SpringBoot中Redis的RedisConfig 二
  • (void) (_x == _y)的作用
  • (八)Flink Join 连接
  • (草履虫都可以看懂的)PyQt子窗口向主窗口传递参数,主窗口接收子窗口信号、参数。
  • (二)什么是Vite——Vite 和 Webpack 区别(冷启动)
  • (附源码)spring boot基于Java的电影院售票与管理系统毕业设计 011449
  • (附源码)ssm捐赠救助系统 毕业设计 060945
  • (论文阅读26/100)Weakly-supervised learning with convolutional neural networks
  • (三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)
  • (算法)区间调度问题
  • (一)springboot2.7.6集成activit5.23.0之集成引擎