
NLP Text Preprocessing: News Topic Classification Case Study

I. Case Overview

1. About the news topic classification task

Given the text of a news report as input, the model predicts which type of news it most likely belongs to. This is a classic text classification problem. Here we assume the types are mutually exclusive, i.e. each text belongs to exactly one type.

2. Data file preview
"""
File description:
train.csv is the training data, 120,000 rows in total; test.csv is the validation data, 7,600 rows in total;
classes.txt maps the labels to news topics and contains four entries, 'World', 'Sports', 'Business', 'Sci/Tech', representing the four topics;
readme.txt is the English description of the dataset.
"""
- data/
    - ag_news_csv/
        classes.txt
        readme.txt
        test.csv
        train.csv
# Preview of train.csv:
# train.csv consists of 3 comma-separated columns: label, news title, news summary;
# the labels are "1", "2", "3", "4", corresponding in order to the entries in classes.txt.
"3","Wall St. Bears Claw Back Into the Black (Reuters)","Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again."
"3","Carlyle Looks Toward Commercial Aerospace (Reuters)","Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market."
"3","Oil and Economy Cloud Stocks' Outlook (Reuters)","Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums."
"3","Iraq Halts Oil Exports from Main Southern Pipeline (Reuters)","Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday."
"3","Oil prices soar to all-time record, posing new menace to US economy (AFP)","AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections."
"3","Stocks End Up, But Near Year Lows (Reuters)","Reuters - Stocks ended slightly higher on Friday\but stayed near lows for the year as oil prices surged past  #36;46\a barrel, offsetting a positive outlook from computer maker\Dell Inc. (DELL.O)"
"3","Money Funds Fell in Latest Week (AP)","AP - Assets of the nation's retail money market mutual funds fell by  #36;1.17 billion in the latest week to  #36;849.98 trillion, the Investment Company Institute said Thursday."
"3","Fed minutes show dissent over inflation (USATODAY.com)","USATODAY.com - Retail sales bounced back a bit in July, and new claims for jobless benefits fell last week, the government said Thursday, indicating the economy is improving from a midsummer slump."
"3","Safety Net (Forbes.com)","Forbes.com - After earning a PH.D. in Sociology, Danny Bazil Riley started to work as the general manager at a commercial real estate firm at an annual base salary of  #36;70,000. Soon after, a financial planner stopped by his desk to drop off brochures about insurance benefits available through his employer. But, at 32, ""buying insurance was the furthest thing from my mind,"" says Riley."
"3","Wall St. Bears Claw Back Into the Black"," NEW YORK (Reuters) - Short-sellers, Wall Street's dwindling  band of ultra-cynics, are seeing green again."
3. Loading the data from local files; the implementation is as follows
from torchtext.legacy.datasets.text_classification import _csv_iterator, _create_data_from_iterator, TextClassificationDataset
from torchtext.utils import extract_archive
from torchtext.vocab import build_vocab_from_iterator, Vocab

# Load the data locally; on the VM the data lives in /root/data/ag_news_csv
# Define the loading function
def setup_datasets(ngrams=2, vocab_train=None, vocab_test=None, include_unk=False):
    train_csv_path = 'data/ag_news_csv/train.csv'
    test_csv_path = 'data/ag_news_csv/test.csv'
    if vocab_train is None:
        vocab_train = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    elif not isinstance(vocab_train, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    if vocab_test is None:
        vocab_test = build_vocab_from_iterator(_csv_iterator(test_csv_path, ngrams))
    elif not isinstance(vocab_test, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    train_data, train_labels = _create_data_from_iterator(
        vocab_train, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    test_data, test_labels = _create_data_from_iterator(
        vocab_test, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab_train, train_data, train_labels),
            TextClassificationDataset(vocab_test, test_data, test_labels))

# Call the function to load the local data
train_dataset, test_dataset = setup_datasets()
print("train_dataset", train_dataset)

II. Case Implementation

The case can be implemented in the following five steps:

  • Step 1: Build a text classification model with an Embedding layer.
  • Step 2: Batch the data.
  • Step 3: Build the training and validation functions.
  • Step 4: Train and validate the model.
  • Step 5: Inspect the word vectors learned by the embedding layer.
1. Build a text classification model with an Embedding layer
import torch
import torch.nn as nn
import torch.nn.functional as F

# Specify the batch size
BATCH_SIZE = 16

# Detect available devices; use the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TextSentiment(nn.Module):
    """Text classification model"""

    def __init__(self, vocab_size, embed_dim, num_class):
        """
        description: class initializer
        :param vocab_size: total number of distinct words in the corpus
        :param embed_dim: dimensionality of the word embeddings
        :param num_class: total number of text categories
        """
        super().__init__()
        # Instantiate the embedding layer; sparse=True means only part of the
        # weights are updated each time gradients are computed for this layer.
        self.embedding = nn.Embedding(vocab_size, embed_dim, sparse=True)
        # Instantiate the linear layer; its arguments are embed_dim and num_class.
        self.fc = nn.Linear(embed_dim, num_class)
        # Initialize the weights of each layer
        self.init_weights()

    def init_weights(self):
        """Weight initialization"""
        # Range for the initial weights
        initrange = 0.5
        # Initialize each layer's weights from a uniform distribution
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        # Initialize the bias to zero
        self.fc.bias.data.zero_()

    def forward(self, text):
        """
        :param text: the numericalized text
        :return: a tensor with one entry per class, used to predict the category
        """
        # Get the embedding result
        # >>> embedded.shape
        # (m, 32) where m is the total number of tokens in a BATCH_SIZE-sized batch
        embedded = self.embedding(text)
        # Next we need to turn (m, 32) into (BATCH_SIZE, 32)
        # so the loss can be computed after the fc layer.
        # Since m is known to be much larger than BATCH_SIZE = 16,
        # integer-divide m by BATCH_SIZE to get the number c of whole groups.
        c = embedded.size(0) // BATCH_SIZE
        # Keep the first c * BATCH_SIZE vectors of embedded so that
        # the number of vectors is divisible by BATCH_SIZE.
        embedded = embedded[:BATCH_SIZE * c]
        # We want to average every c rows of embedded via average pooling,
        # but avg_pool1d pools along the last dimension and expects 3-D input,
        # so transpose the new embedded and add a leading dimension.
        embedded = embedded.transpose(1, 0).unsqueeze(0)
        # Apply average pooling with kernel size c,
        # i.e. average every c elements into one value.
        embedded = F.avg_pool1d(embedded, kernel_size=c)
        # Finally, drop the added dimension and transpose back before the fc layer
        return self.fc(embedded[0].transpose(1, 0))

Instantiate the model:

# Get the total number of distinct words in the corpus
VOCAB_SIZE = len(train_dataset.get_vocab())
# Specify the embedding dimension
EMBED_DIM = 32
# Get the total number of classes
NUM_CLASS = len(train_dataset.get_labels())
# Instantiate the model
model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUM_CLASS).to(device)
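The truncate-then-average trick in the model's forward pass can be checked with plain Python lists. Below is a minimal sketch with toy sizes (BATCH_SIZE = 2, embed_dim = 3, m = 5 token vectors, all made up for illustration) that mimics truncating to BATCH_SIZE * c rows and averaging each consecutive group of c rows, as avg_pool1d(kernel_size=c) does:

```python
BATCH_SIZE = 2
embedded = [  # 5 token "vectors" of dimension 3, standing in for the (m, 32) tensor
    [1.0, 2.0, 3.0],
    [3.0, 4.0, 5.0],
    [0.0, 0.0, 0.0],
    [2.0, 2.0, 2.0],
    [9.0, 9.0, 9.0],  # dropped by the truncation below
]

c = len(embedded) // BATCH_SIZE        # c = 2 whole groups per output row
embedded = embedded[:BATCH_SIZE * c]   # keep the first 4 rows

# Average each consecutive group of c rows, one output row per group
pooled = [
    [sum(row[d] for row in embedded[b * c:(b + 1) * c]) / c for d in range(3)]
    for b in range(BATCH_SIZE)
]
print(pooled)  # [[2.0, 3.0, 4.0], [1.0, 1.0, 1.0]]
```

The output has exactly BATCH_SIZE rows, which is what allows the fc layer's output to be matched against the batch of labels.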
2. Batch the data
def generate_batch(batch):
    """
    description: batch-generation function
    :param batch: a list of batch_size (label, sample-tensor) tuples, e.g.
                  [(label1, sample1), (label2, sample2), ..., (labelN, sampleN)]
    return: the samples and labels each collected into a tensor, e.g.
            text = tensor([sample1, sample2, ..., sampleN])
            label = tensor([label1, label2, ..., labelN])
    """
    # Extract the label tensor from the batch
    label = torch.tensor([entry[0] for entry in batch])
    # Extract the sample tensors from the batch and concatenate them
    text = [entry[1] for entry in batch]
    text = torch.cat(text)
    # Return the result
    return text, label

Call it:

# Assume an input:
batch = [(1, torch.tensor([3, 23, 2, 8])), (0, torch.tensor([3, 45, 21, 6]))]
res = generate_batch(batch)
print(res)
# The two input samples are concatenated accordingly:
# (tensor([ 3, 23,  2,  8,  3, 45, 21,  6]), tensor([1, 0]))
3. Build the training and validation functions
# Import torch's data loader
from torch.utils.data import DataLoader

def train(train_data):
    """Model training function"""
    # Initialize the training loss and accuracy to 0
    train_loss = 0
    train_acc = 0
    # Use the data loader to produce BATCH_SIZE-sized batches for training;
    # data is a generator of batches processed by generate_batch
    data = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    # Iterate over data, updating the parameters with each batch
    for i, (text, cls) in enumerate(data):
        # Reset the optimizer's gradients to zero
        optimizer.zero_grad()
        # Feed one batch through the model to get the output
        output = model(text)
        # Compute the loss from the true labels and the model output
        loss = criterion(output, cls)
        # Add this batch's loss to the running total
        train_loss += loss.item()
        # Backpropagate the error
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Add this batch's correct predictions to the running total
        train_acc += (output.argmax(1) == cls).sum().item()
    # Adjust the optimizer's learning rate
    scheduler.step()
    # Return this epoch's average loss and average accuracy
    return train_loss / len(train_data), train_acc / len(train_data)

def valid(valid_data):
    """Model validation function"""
    # Initialize the validation loss and accuracy to 0
    # (separate names so the per-batch loss below doesn't shadow the totals)
    valid_loss = 0
    valid_acc = 0
    # As in training, use DataLoader to get a validation-data generator
    data = DataLoader(valid_data, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    # Validate batch by batch
    for text, cls in data:
        # No gradients are needed during validation
        with torch.no_grad():
            # Get the model output
            output = model(text)
            # Compute the loss
            loss = criterion(output, cls)
            # Add the loss and accuracy to the running totals
            valid_loss += loss.item()
            valid_acc += (output.argmax(1) == cls).sum().item()
    # Return this round's average loss and average accuracy
    return valid_loss / len(valid_data), valid_acc / len(valid_data)
4. Train and validate the model
# Import the time utilities
import time
# Import the random data-splitting utility
from torch.utils.data.dataset import random_split

# Specify the number of training epochs
N_EPOCHS = 10
# Define the initial validation loss
min_valid_loss = float('inf')

# Choose the loss function; here the predefined cross-entropy loss
criterion = torch.nn.CrossEntropyLoss().to(device)
# Choose the stochastic gradient descent optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
# Choose the StepLR scheduler to decay the learning rate
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

# Take 95% of train_dataset as the training set; first compute its length
train_len = int(len(train_dataset) * 0.95)
# Then use random_split to shuffle and split into training and validation sets
sub_train_, sub_valid_ = \
    random_split(train_dataset, [train_len, len(train_dataset) - train_len])

# Run each training epoch
for epoch in range(N_EPOCHS):
    # Record this epoch's start time
    start_time = time.time()
    # Call train and valid to get the average loss and average accuracy
    train_loss, train_acc = train(sub_train_)
    valid_loss, valid_acc = valid(sub_valid_)
    # Compute the total elapsed time for training and validation (seconds)
    secs = int(time.time() - start_time)
    # Express it in minutes and seconds
    mins = secs // 60
    secs = secs % 60
    # Print the elapsed time, average loss, and average accuracy
    print('Epoch: %d' % (epoch + 1), " | time in %d minutes, %d seconds" % (mins, secs))
    print(f'\tLoss: {train_loss:.4f}(train)\t|\tAcc: {train_acc * 100:.1f}%(train)')
    print(f'\tLoss: {valid_loss:.4f}(valid)\t|\tAcc: {valid_acc * 100:.1f}%(valid)')
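With lr=4.0, step size 1, and gamma=0.9, StepLR multiplies the learning rate by 0.9 after every epoch, so the rate in effect during epoch k is just 4.0 * 0.9**k. A stdlib-only sketch of the resulting schedule over the 10 epochs:

```python
base_lr = 4.0
gamma = 0.9

# Learning rate in effect during each of the 10 epochs
schedule = [base_lr * gamma ** epoch for epoch in range(10)]
print([round(lr, 4) for lr in schedule])
# [4.0, 3.6, 3.24, 2.916, 2.6244, 2.362, 2.1258, 1.9132, 1.7219, 1.5497]
```

The unusually large starting rate of 4.0 is workable here only because the model is a single embedding-plus-linear layer; the decay brings it down geometrically.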

Output:

120000lines [00:06, 17834.17lines/s]
120000lines [00:11, 10071.77lines/s]
7600lines [00:00, 10432.95lines/s]
Epoch: 1  | time in 0 minutes, 36 seconds
    Loss: 0.0592(train) |   Acc: 63.9%(train)
    Loss: 0.0005(valid) |   Acc: 69.2%(valid)
Epoch: 2  | time in 0 minutes, 37 seconds
    Loss: 0.0507(train) |   Acc: 71.3%(train)
    Loss: 0.0005(valid) |   Acc: 70.7%(valid)
Epoch: 3  | time in 0 minutes, 36 seconds
    Loss: 0.0484(train) |   Acc: 72.8%(train)
    Loss: 0.0005(valid) |   Acc: 71.4%(valid)
Epoch: 4  | time in 0 minutes, 36 seconds
    Loss: 0.0474(train) |   Acc: 73.4%(train)
    Loss: 0.0004(valid) |   Acc: 72.0%(valid)
Epoch: 5  | time in 0 minutes, 36 seconds
    Loss: 0.0455(train) |   Acc: 74.8%(train)
    Loss: 0.0004(valid) |   Acc: 72.5%(valid)
Epoch: 6  | time in 0 minutes, 36 seconds
    Loss: 0.0451(train) |   Acc: 74.9%(train)
    Loss: 0.0004(valid) |   Acc: 72.3%(valid)
Epoch: 7  | time in 0 minutes, 36 seconds
    Loss: 0.0446(train) |   Acc: 75.3%(train)
    Loss: 0.0004(valid) |   Acc: 72.0%(valid)
Epoch: 8  | time in 0 minutes, 36 seconds
    Loss: 0.0437(train) |   Acc: 75.9%(train)
    Loss: 0.0004(valid) |   Acc: 71.4%(valid)
Epoch: 9  | time in 0 minutes, 36 seconds
    Loss: 0.0431(train) |   Acc: 76.2%(train)
    Loss: 0.0004(valid) |   Acc: 72.7%(valid)
Epoch: 10  | time in 0 minutes, 36 seconds
    Loss: 0.0426(train) |   Acc: 76.6%(train)
    Loss: 0.0004(valid) |   Acc: 72.6%(valid)
5. Inspect the word vectors learned by the embedding layer
# Print the Embedding matrix obtained from the model's state dict
print(model.state_dict()['embedding.weight'])
tensor([[ 0.4401, -0.4177, -0.4161,  ...,  0.2497, -0.4657, -0.1861],
        [-0.2574, -0.1952,  0.1443,  ..., -0.4687, -0.0742,  0.2606],
        [-0.1926, -0.1153, -0.0167,  ..., -0.0954,  0.0134, -0.0632],
        ...,
        [-0.0780, -0.2331, -0.3656,  ..., -0.1899,  0.4083,  0.3002],
        [-0.0696,  0.4396, -0.1350,  ...,  0.1019,  0.2792, -0.4749],
        [-0.2978,  0.1872, -0.1994,  ...,  0.3435,  0.4729, -0.2608]])
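Each row of this matrix is the learned vector for one vocabulary index, so looking up a word's vector means mapping the word to its index and taking that row. A stdlib-only toy sketch of the idea (the stoi mapping and 4-dimensional matrix below are made up for illustration; in the real code the index would come from train_dataset.get_vocab() and the row from the tensor printed above):

```python
# Hypothetical string-to-index mapping and embedding matrix (one row per word)
stoi = {'<unk>': 0, 'oil': 1, 'stocks': 2}
embedding_weight = [
    [0.44, -0.41, -0.41, 0.24],
    [-0.25, -0.19, 0.14, -0.46],
    [-0.19, -0.11, -0.01, -0.09],
]

def word_vector(word):
    # Unknown words fall back to the <unk> row
    idx = stoi.get(word, stoi['<unk>'])
    return embedding_weight[idx]

print(word_vector('oil'))      # [-0.25, -0.19, 0.14, -0.46]
print(word_vector('economy'))  # not in stoi, falls back to the <unk> vector
```

After training, nearby rows for semantically related words are what makes the embedding useful as a learned word representation.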
