当前位置: 首页 > news >正文

基于Python的自然语言处理系列(8):使用TorchText进行新闻分类

        在本篇文章中,我们将探讨如何使用TorchText库来进行新闻分类任务。我们将使用AG_NEWS数据集,它包含四个分类标签:“World”,“Sports”,“Business”,“Sci/Tech”。我们将详细讲解数据的加载与预处理、模型设计、训练与评估,以及如何在PyTorch中结合RNN进行分类。

1. TorchText简介

        TorchText 是 PyTorch 的一个辅助库,专门用于处理文本数据。虽然 PyTorch 本身已经非常强大,能够处理多种类型的数据集,但TorchText 提供了许多专为自然语言处理(NLP)任务优化的功能,例如文本的加载、处理、词汇构建、批处理等。在这篇文章中,我们将展示如何使用 TorchText 处理新闻分类任务。

import torch, torchdata, torchtext
from torch import nn
import time# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)# 保持随机性一致
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

2. 数据加载与预处理

        我们将使用 AG_NEWS 数据集,该数据集已经包含在 TorchText 库中。我们将通过 DataPipe 加载数据,并进行一些简单的探索性数据分析(EDA)。

2.1 数据加载

from torchtext.datasets import AG_NEWS# 加载数据
train, test = AG_NEWS()
train_size = len(list(iter(train)))
print(f"训练集大小: {train_size}")

2.2 数据预处理

        为了使模型高效训练,我们将对数据进行随机拆分,并将数据集缩小到较小的子集,以便快速进行模型训练。

too_much, train, valid = train.random_split(total_length=train_size, weights={"too_much": 0.7, "train": 0.2, "valid": 0.1}, seed=999)train_size = len(list(iter(train)))
val_size = len(list(iter(valid)))
print(f"训练集大小: {train_size}, 验证集大小: {val_size}")

        接下来,我们需要将文本转换为整数表示,即使用 tokenizer 将句子拆分为词,并通过 build_vocab_from_iterator 创建词汇表。

from torchtext.data.utils import get_tokenizer
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')from torchtext.vocab import build_vocab_from_iterator
def yield_tokens(data_iter):for _, text in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab["<unk>"])

3. 模型设计

        接下来,我们将设计一个简单的基于RNN的模型来进行文本分类任务。模型包括三个主要部分:embedding层RNN层全连接层

class simpleRNN(nn.Module):def __init__(self, input_dim, emb_dim, hid_dim, output_dim):super().__init__()self.embedding = nn.Embedding(input_dim, emb_dim)self.rnn = nn.RNN(emb_dim, hid_dim, batch_first=True)self.fc  = nn.Linear(hid_dim, output_dim)def forward(self, text):embedded = self.embedding(text)output, hidden = self.rnn(embedded)return self.fc(hidden.squeeze(0))

模型初始化与参数设置

        我们为模型的每一层初始化参数,以确保更好的学习效果。

def initialize_weights(m):if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.zeros_(m.bias)elif isinstance(m, nn.RNN):for name, param in m.named_parameters():if 'bias' in name:nn.init.zeros_(param)elif 'weight' in name:nn.init.xavier_normal_(param)input_dim  = len(vocab)
hid_dim    = 256
emb_dim    = 200
output_dim = 4model = simpleRNN(input_dim, emb_dim, hid_dim, output_dim).to(device)
model.apply(initialize_weights)

4. 模型训练与评估

        我们定义了训练和评估的函数。每个epoch结束后,我们会计算训练和验证的损失与准确率,并保存最佳的验证集模型。

import torch.optim as optimoptimizer = optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()def train(model, loader, optimizer, criterion, loader_length):model.train()epoch_loss, epoch_acc = 0, 0for label, text in loader:label, text = label.to(device), text.to(device)predictions = model(text).squeeze(1)loss = criterion(predictions, label)acc = (predictions.argmax(1) == label).sum() / len(label)optimizer.zero_grad()loss.backward()optimizer.step()epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / loader_length, epoch_acc / loader_length

        同样,我们也定义了模型的评估函数:

def evaluate(model, loader, criterion, loader_length):model.eval()epoch_loss, epoch_acc = 0, 0with torch.no_grad():for label, text in loader:label, text = label.to(device), text.to(device)predictions = model(text).squeeze(1)loss = criterion(predictions, label)acc = (predictions.argmax(1) == label).sum() / len(label)epoch_loss += loss.item()epoch_acc += acc.item()return epoch_loss / loader_length, epoch_acc / loader_length

5. 模型测试

在训练完成后,我们可以测试模型在随机新闻样本上的表现。你可以输入任意新闻文本,模型将预测它属于哪个类别。

def predict(text):model.eval()tokens = tokenizer(text)indices = vocab(tokens)text_tensor = torch.tensor(indices).unsqueeze(0).to(device)with torch.no_grad():prediction = model(text_tensor)return prediction.argmax(1).item()test_str = "Google is facing major challenges in its recent business strategies."
pred = predict(test_str)
print(f'预测类别: {pred}')

结语

        在这篇文章中,我们探索了如何使用TorchText库进行新闻分类任务。从数据的加载与预处理,到模型的设计与训练,我们详细讲解了每个步骤,尤其是如何利用RNN来处理序列数据。通过这种方式,读者不仅可以掌握如何构建自然语言处理(NLP)pipeline,还能了解如何在实际项目中应用这些技术。

        虽然我们使用了一个简单的RNN模型,但这仅仅是NLP世界的一小部分。在后续的文章中,我们将继续优化模型,尝试更多的高级架构(如LSTM、GRU和Transformer),并探讨如何提升分类的性能和准确性。希望这篇文章能为你在NLP的学习旅程中提供有益的帮助,敬请期待更多深入的技术内容!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • 迎接AI时代的机遇与挑战:个人成长指南
  • C语言 | Leetcode C语言题解之第404题左叶子之和
  • c++类和对象(3):默认成员函数(下)
  • 电巢科技携Ecosmos元宇宙产品亮相第25届中国光博会
  • Java | Leetcode Java题解之第404题左叶子之和
  • 光伏选址和设计离不开气象分析!
  • Android 蓝牙三方和动态权限三方
  • 【Android安全】Keystone和Capstone
  • Flink CEP(复杂事件处理)高级进阶
  • 【C++题解】1406. 石头剪刀布?
  • vue国际化vue-i18n搭配i18n-ally实现多语言国际化
  • linux gcc 静态库的简单介绍
  • 438 找到字符串中所有字母异位词
  • 以太网传输出现不分包
  • Facebook主页,广告账户,BM被封分别怎么解决?
  • [LeetCode] Wiggle Sort
  • [Vue CLI 3] 配置解析之 css.extract
  • 《Javascript高级程序设计 (第三版)》第五章 引用类型
  • bootstrap创建登录注册页面
  • C语言笔记(第一章:C语言编程)
  • Django 博客开发教程 16 - 统计文章阅读量
  • hadoop入门学习教程--DKHadoop完整安装步骤
  • JavaWeb(学习笔记二)
  • Js基础——数据类型之Null和Undefined
  • python 装饰器(一)
  • python3 使用 asyncio 代替线程
  • React-flux杂记
  • 成为一名优秀的Developer的书单
  • 多线程 start 和 run 方法到底有什么区别?
  • 分享自己折腾多时的一套 vue 组件 --we-vue
  • 服务器从安装到部署全过程(二)
  • 服务器之间,相同帐号,实现免密钥登录
  • 面试题:给你个id,去拿到name,多叉树遍历
  • 前嗅ForeSpider中数据浏览界面介绍
  • 深入浅出Node.js
  • 使用iElevator.js模拟segmentfault的文章标题导航
  • 进程与线程(三)——进程/线程间通信
  • ​​​【收录 Hello 算法】9.4 小结
  • ​Linux Ubuntu环境下使用docker构建spark运行环境(超级详细)
  • ​必胜客礼品卡回收多少钱,回收平台哪家好
  • ​你们这样子,耽误我的工作进度怎么办?
  • ​软考-高级-信息系统项目管理师教程 第四版【第14章-项目沟通管理-思维导图】​
  • (22)C#传智:复习,多态虚方法抽象类接口,静态类,String与StringBuilder,集合泛型List与Dictionary,文件类,结构与类的区别
  • (3)选择元素——(17)练习(Exercises)
  • (BFS)hdoj2377-Bus Pass
  • (Redis使用系列) SpirngBoot中关于Redis的值的各种方式的存储与取出 三
  • (草履虫都可以看懂的)PyQt子窗口向主窗口传递参数,主窗口接收子窗口信号、参数。
  • (二)springcloud实战之config配置中心
  • (附源码)ssm基于web技术的医务志愿者管理系统 毕业设计 100910
  • (附源码)ssm捐赠救助系统 毕业设计 060945
  • (黑马点评)二、短信登录功能实现
  • (全注解开发)学习Spring-MVC的第三天
  • (实战篇)如何缓存数据
  • (转)http协议
  • .Mobi域名介绍