当前位置: 首页 > news >正文

如何使用“预训练的词向量”,做文本分类

  之前一直不知道,怎么使用预训练得词向量,现在终于知道了!!!

 

  代码可以直接运行

 

 

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from keras.layers import Embedding, LSTM, GRU, Dropout, Dense, Input
from keras.models import Model, Sequential, load_model
from keras.preprocessing import sequence
from keras.datasets import imdb
import gensim
from gensim.models.word2vec import Word2Vec


'''
以LSTM为例,LSTM的长度为MAX_SEQ_LEN;每个cell输入一个单词,这个单词用one-hot表示
词向量矩阵是embedMatrix,记录词典中每个词的词向量;词的idx,对应embedMatrix的行号
“该词的ont-hot向量”点乘“embedMatrix”,便得到“该词的词向量表示”

比如:词典有5个词,也即:word2idx = {_stopWord:0, love:1, I:2, my:3, you:4, friend:5, my:6};每个词映射到2维;
输入句子:"I love my pen",  #pen是停用词,其idx设为0
                     [0,       0]
                     [0.3,   0.1]
[0, 0, 1, 0, 0, 0]   [-0.4, -0.5]   [-0.4, -0.5]
[0, 1, 0, 0, 0, 0] · [0.5,   0.2] = [0.3,   0.1]
[0, 0, 0, 0, 0, 1]   [-0.7,  0.6]   [-0.3, -0.8]
[1, 0, 0, 0, 0, 0]   [-0.3, -0.8]   [0,       0]
                     [0.5,   0.2]
'''
MAX_SEQ_LEN = 250
inPath = '../data/'


def train_W2V(sentenList, embedSize=300, epoch_num=1):
    w2vModel = Word2Vec(sentences=sentenList, hs=0, negative=5, min_count=5, window=5, iter=epoch_num, size=embedSize)
    w2vModel.save(inPath + 'w2vModel')
    return w2vModel


def build_word2idx_embedMatrix(w2vModel): word2idx = {"_stopWord": 0} # 这里加了一行是用来过滤停用词的。 vocab_list = [(w, w2vModel.wv[w]) for w, v in w2vModel.wv.vocab.items()] embedMatrix = np.zeros((len(w2vModel.wv.vocab.items()) + 1, w2vModel.vector_size)) for i in range(0, len(vocab_list)): word = vocab_list[i][0] word2idx[word] = i + 1 embedMatrix[i + 1] = vocab_list[i][1] return word2idx, embedMatrix def make_deepLearn_data(sentenList, word2idx): X_train_idx = [[word2idx.get(w, 0) for w in sen] for sen in sentenList] X_train_idx = np.array(sequence.pad_sequences(X_train_idx, maxlen=MAX_SEQ_LEN)) # 必须是np.array()类型 return X_train_idx


def Lstm_model(embedMatrix): # 注意命名不能和库函数同名,之前命名为LSTM()就出很大的错误!! input_layer = Input(shape=(MAX_SEQ_LEN,), dtype='int32') embedding_layer = Embedding(input_dim=len(embedMatrix), output_dim=len(embedMatrix[0]), weights=[embedMatrix], # 表示直接使用预训练的词向量 trainable=False)(input_layer) # False表示不对词向量微调 Lstm_layer = LSTM(units=20, return_sequences=False)(embedding_layer) drop_layer = Dropout(0.5)(Lstm_layer) dense_layer = Dense(units=1, activation="sigmoid")(drop_layer) model = Model(inputs=[input_layer], outputs=[dense_layer]) model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) print(model.summary()) return model if __name__ == '__main__': (X_train, y_train), (X_test, y_test) = imdb.load_data() X_all = (list(X_train) + list(X_test))[0: 1000] y_all = (list(y_train) + list(y_test))[0: 1000] print(len(X_all), len(y_all)) imdb_word2idx = imdb.get_word_index() imdb_idx2word = dict((idx, word) for (word, idx) in imdb_word2idx.items()) X_all = [[imdb_idx2word.get(idx - 3, '?') for idx in sen][1:] for sen in X_all] # print(y_all[0: 1], X_all[0: 1]) w2vModel = train_W2V(X_all, embedSize=300, epoch_num=2) word2idx, embedMatrix = build_word2idx_embedMatrix(w2vModel) # 制作word2idx和embedMatrix X_all_idx = make_deepLearn_data(X_all, word2idx) # 制作符合要求的深度学习数据 y_all_idx = np.array(y_all) # 一定要注意,X_all和y_all必须是np.array()类型,否则报错 # print(y_all_idx[0: 1], X_all_idx[0: 1]) X_tra_idx, X_val_idx, y_tra_idx, y_val_idx = train_test_split(X_all_idx, y_all_idx, test_size=0.2, random_state=0, stratify=y_all_idx) print('————————————————模型的训练和预测————————————————') model = Lstm_model(embedMatrix) model.fit(X_tra_idx, y_tra_idx, validation_data=(X_val_idx, y_val_idx), epochs=1, batch_size=100, verbose=1) y_pred = model.predict(X_val_idx) y_pred_idx = [1 if prob[0] > 0.5 else 0 for prob in y_pred] print(f1_score(y_val_idx, y_pred_idx)) print(confusion_matrix(y_val_idx, y_pred_idx))

 

转载于:https://www.cnblogs.com/liguangchuang/p/10074075.html

相关文章:

  • 字符串匹配基础上
  • Curator教程(一)快速入门
  • 阿里云搭建hadoop集群服务器,内网、外网访问问题(详解。。。)
  • 枚举与switch组合使用
  • 如何用纯 CSS 创作一个货车 loader
  • 阿里云马劲:保证云产品持续拥有稳定性的实践和思考
  • C# 获取对象 大小 Marshal.SizeOf (sizeof 只能在不安全的上下文中使用)
  • Oracle-SQL*Plus 简单操作
  • thinkphp 使用paginate分页搜索带参数
  • Money去哪了- 每日站立会议
  • ethereumjs/merkle-patricia-tree-2-API
  • 腾讯音乐赴美IPO仅11亿美元,疑受科技股抛售和中美贸易战影响
  • 【quick-cocos2d-lua】 基本类及用法
  • mysql安装时无法启动服务解决方案
  • Linux初级运维(十五)——bash脚本编程之函数
  • css选择器
  • Date型的使用
  • express + mock 让前后台并行开发
  • iOS 系统授权开发
  • JavaScript新鲜事·第5期
  • java中的hashCode
  • Protobuf3语言指南
  • SAP云平台里Global Account和Sub Account的关系
  • Spring-boot 启动时碰到的错误
  • Yeoman_Bower_Grunt
  • 服务器从安装到部署全过程(二)
  • 极限编程 (Extreme Programming) - 发布计划 (Release Planning)
  • 猫头鹰的深夜翻译:Java 2D Graphics, 简单的仿射变换
  • 普通函数和构造函数的区别
  • 前端技术周刊 2019-02-11 Serverless
  • 前端之React实战:创建跨平台的项目架构
  • 小程序01:wepy框架整合iview webapp UI
  • 深度学习之轻量级神经网络在TWS蓝牙音频处理器上的部署
  • gunicorn工作原理
  • Hibernate主键生成策略及选择
  • zabbix3.2监控linux磁盘IO
  • ​secrets --- 生成管理密码的安全随机数​
  • #绘制圆心_R语言——绘制一个诚意满满的圆 祝你2021圆圆满满
  • #使用清华镜像源 安装/更新 指定版本tensorflow
  • (ibm)Java 语言的 XPath API
  • (libusb) usb口自动刷新
  • (二)pulsar安装在独立的docker中,python测试
  • (附源码)spring boot基于Java的电影院售票与管理系统毕业设计 011449
  • (六)激光线扫描-三维重建
  • (篇九)MySQL常用内置函数
  • (七)MySQL是如何将LRU链表的使用性能优化到极致的?
  • (四)鸿鹄云架构一服务注册中心
  • (一)WLAN定义和基本架构转
  • (一)基于IDEA的JAVA基础12
  • (转)memcache、redis缓存
  • *setTimeout实现text输入在用户停顿时才调用事件!*
  • ... fatal error LINK1120:1个无法解析的外部命令 的解决办法
  • .net core 6 集成 elasticsearch 并 使用分词器
  • .NET core 自定义过滤器 Filter 实现webapi RestFul 统一接口数据返回格式
  • .NETCORE 开发登录接口MFA谷歌多因子身份验证