当前位置: 首页 > news >正文

昇思25天学习打卡营第12天 |昇思MindSpore 基于 MindSpore 通过 GPT 实现情感分类

一、环境准备

# 安装指定版本的 MindSpore
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14# 安装 MindNLP 和 jieba
!pip install mindnlp
!pip install jieba# 设置 HF_ENDPOINT
%env HF_ENDPOINT=https://hf-mirror.com

二、原理公式
情感分类的基本原理通常基于对文本特征的提取和分析。在使用 GPT 模型时,模型通过学习大量的文本数据,自动捕捉语言的模式和规律。

常见的数学公式可能涉及到损失函数的计算,如交叉熵损失函数:
Gamma公式展示 Γ ( n ) = ( n − 1 ) ! ∀ n ∈ N \Gamma(n) = (n-1)!\quad\forall n\in\mathbb N Γ(n)=(n1)!nN 是通过 Euler integral

L = − 1 N ∑ i y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) L = -\frac{1}{N} \sum_{i} y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) L=N1iyilog(y^i)+(1yi)log(1y^i)
其中, y i y_i yi 是真实标签, y ^ i \hat{y}_i y^i是模型预测的概率。

三、推理流程

  1. 数据加载与预处理:
    • 加载 imdb 数据集,并将其划分为训练集、验证集和测试集。
    • 对文本数据进行分词、添加特殊标记、截断或填充等操作,将其转换为适合模型输入的格式。
  2. 模型构建与训练:
    • 基于预训练的 openai-gpt 模型进行微调。
    • 定义优化器(如 Adam )和损失函数,通过反向传播不断更新模型的参数,以最小化损失函数。
  3. 模型评估与预测:
    • 使用验证集或测试集对训练好的模型进行评估,计算准确率等指标。
    • 对于新的输入文本,通过模型进行预测,得到其情感分类结果。

四、操作流程

  1. 加载所需的库和模块:
import os
import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn
from mindnlp.dataset import load_dataset
from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
import time
  1. 加载 imdb 数据集,并进行划分:
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
  1. 定义数据处理函数 process_dataset
def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):# 内部处理逻辑#...return dataset
- 参数:- `dataset` :待处理的数据集。- `tokenizer` :用于分词的工具。- `max_seq_len` :序列的最大长度。- `batch_size` :批次大小。- `shuffle` :是否打乱数据集。
- 功能:对输入的数据集进行分词、类型转换、批次处理等操作。
- 例句:
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
  1. 进行分词器的设置和特殊标记的添加:
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
  1. 划分训练集和验证集:
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
  1. 处理数据集:
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)
  1. 定义模型、优化器和评估指标:
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()
  1. 设置回调函数并进行训练:
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)
  1. 进行评估:
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

五、情感分类模型的应用领域

  1. 客户反馈分析:帮助企业了解客户对产品或服务的满意度和意见。
  2. 社交媒体监测:洞察公众对特定话题、事件或品牌的情感倾向。
  3. 在线评论分类:对电商平台、旅游网站等的用户评论进行分类。
  4. 舆情分析:了解社会舆论对政府政策、公共事件的态度。

六、常用的情感分类模型

  1. 朴素贝叶斯分类器
  2. 支持向量机(SVM)
  3. 决策树
  4. 卷积神经网络(CNN)
  5. 循环神经网络(RNN)及其变体,如长短期记忆网络(LSTM)和门控循环单元(GRU)

七、如何评估情感分类模型

  1. 准确率(Accuracy):正确分类的样本数占总样本数的比例。
  2. 召回率(Recall):正确分类的正例样本数占实际正例样本数的比例。
  3. 精确率(Precision):正确分类的正例样本数占预测为正例样本数的比例。
  4. F1 值:综合考虑精确率和召回率的调和平均值。
  5. 混淆矩阵(Confusion Matrix):直观展示不同类别之间的分类情况。

八、调用库的名称、功能

  1. mindspore :用于深度学习模型的构建和训练。
  2. mindnlp :提供自然语言处理相关的数据集和工具。
  3. jieba :中文分词库。

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • CANoe:System Variables模块介绍
  • 只有IP地址没有域名怎么实现HTTPS访问?
  • 自动问答之白嫖文心一言大模型
  • 卡拉OK歌唱比赛活动策划方案
  • 使用flutter做圆形进度条 (桌面端)
  • PicInsight - 制作精美的明信片! | 限时免费
  • 成都云飞浩容文化传媒有限公司领航电商新纪元
  • URL重写
  • 代码随想录算法训练营第四十五天| 115.不同的子序列 、583. 两个字符串的删除操作 、 72. 编辑距离
  • ElasticSearch搜索
  • 【实践出真知】使用Docusaurus将md文档组织起来就是一个网站(写API文档,写教程、写日记、写博客的有福了)
  • python使用selenium切换到了iframe
  • 理解 HTTP 请求中 Query 和 Body 的异同
  • Android经典面试题之Kotlin中 if 和 let的区别
  • C语言100基础拔高题(3)
  • Google 是如何开发 Web 框架的
  • $translatePartialLoader加载失败及解决方式
  • 【编码】-360实习笔试编程题(二)-2016.03.29
  • ➹使用webpack配置多页面应用(MPA)
  • es6
  • HTML5新特性总结
  • Java知识点总结(JDBC-连接步骤及CRUD)
  • MySQL几个简单SQL的优化
  • Redis 中的布隆过滤器
  • vue-cli在webpack的配置文件探究
  • 阿里云购买磁盘后挂载
  • 对JS继承的一点思考
  • 基于Vue2全家桶的移动端AppDEMO实现
  • 基于遗传算法的优化问题求解
  • 蓝海存储开关机注意事项总结
  • 算法---两个栈实现一个队列
  • 想写好前端,先练好内功
  • 由插件封装引出的一丢丢思考
  • 容器镜像
  • # Java NIO(一)FileChannel
  • #etcd#安装时出错
  • #WEB前端(HTML属性)
  • (02)Cartographer源码无死角解析-(03) 新数据运行与地图保存、加载地图启动仅定位模式
  • (C语言)编写程序将一个4×4的数组进行顺时针旋转90度后输出。
  • (附源码)python房屋租赁管理系统 毕业设计 745613
  • (附源码)springboot家庭财务分析系统 毕业设计641323
  • (附源码)ssm旅游企业财务管理系统 毕业设计 102100
  • (每日一问)操作系统:常见的 Linux 指令详解
  • (三)终结任务
  • (一)基于IDEA的JAVA基础12
  • (转)创业的注意事项
  • (转)为C# Windows服务添加安装程序
  • (转载)CentOS查看系统信息|CentOS查看命令
  • *算法训练(leetcode)第三十九天 | 115. 不同的子序列、583. 两个字符串的删除操作、72. 编辑距离
  • .net CHARTING图表控件下载地址
  • .Net 代码性能 - (1)
  • .NET/ASP.NETMVC 深入剖析 Model元数据、HtmlHelper、自定义模板、模板的装饰者模式(二)...
  • .pub是什么文件_Rust 模块和文件 - 「译」
  • /usr/bin/env: node: No such file or directory
  • @autowired注解作用_Spring Boot进阶教程——注解大全(建议收藏!)