Tensorflow2基础代码实战系列之CNN文本分类实战

news/2024/6/15 18:38:46 标签: tensorflow, cnn, 分类

深度学习框架Tensorflow2系列

注:大家觉得博客好的话,别忘了点赞收藏呀,本人每周都会更新关于人工智能和大数据相关的内容,内容多为原创,Python Java Scala SQL 代码,CV NLP 推荐系统等,Spark Flink Kafka Hbase Hive Flume等等~写的都是纯干货,各种顶会的论文解读,一起进步。
这个系列主要和大家分享深度学习框架Tensorflow2的各种api,从基础开始。
#博学谷IT学习技术支持#


文章目录

  • 深度学习框架Tensorflow2系列
  • 前言
  • 一、文本分类任务实战
  • 二、数据集介绍
  • 三、CNN模型解读
  • 四、实战代码
    • 1.数据预处理
    • 2.定义模型
    • 3.模型训练
  • 总结


前言

通过CNN文本分类实战案例,学习Tensorflow2中一些API


一、文本分类任务实战

任务介绍:
数据集构建:影评数据集进行情感分析(分类任务)
词向量模型:加载训练好的词向量或者自己训练都可以
序列网络模型:训练RNN模型进行识别

二、数据集介绍

训练和测试集都是比较简单的电影评价数据集,标签为0和1的二分类,表示对电影的喜欢和不喜欢
在这里插入图片描述

三、CNN模型解读

在这里插入图片描述
通过不同尺度的卷积核[(2,3,4),word_dim] 来提取单词的特征,再进行max_pooling得到一个特征值,最后把所有尺度得到的特征值拼接在一起后,通过全连接进行分类

四、实战代码

1.数据预处理

这里直接加载默认数据集,通过pad_sequences进行截断和填充操作
得到训练集和测试集大小都为(25000,300)
各25000个样本,每个样本长度为300
(25000, 300)
(25000, 300)

import warnings
warnings.filterwarnings("ignore")
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.sequence import pad_sequences

num_features = 3000
sequence_length = 300
embedding_dimension = 100
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=num_features)
x_train = pad_sequences(x_train, maxlen=sequence_length)
x_test = pad_sequences(x_test, maxlen=sequence_length)
print(x_train.shape)
print(x_test.shape)
print(y_train.shape)
print(y_test.shape)

2.定义模型

# 多种卷积核,相当于单词数
filter_sizes=[3,4,5]
def convolution():
    inn = layers.Input(shape=(sequence_length, embedding_dimension, 1))#3维的
    cnns = []
    for size in filter_sizes:
        conv = layers.Conv2D(filters=64, kernel_size=(size, embedding_dimension),
                            strides=1, padding='valid', activation='relu')(inn)
        #需要将多种卷积后的特征图池化成一个特征
        pool = layers.MaxPool2D(pool_size=(sequence_length-size+1, 1), padding='valid')(conv)
        cnns.append(pool)
    # 将得到的特征拼接在一起
    outt = layers.concatenate(cnns)

    model = keras.Model(inputs=inn, outputs=outt)
    return model

def cnn_mulfilter():
    model = keras.Sequential([
        layers.Embedding(input_dim=num_features, output_dim=embedding_dimension,
                        input_length=sequence_length),
        layers.Reshape((sequence_length, embedding_dimension, 1)),
        convolution(),
        layers.Flatten(),
        layers.Dense(10, activation='relu'),
        layers.Dropout(0.2),
        layers.Dense(1, activation='sigmoid')

    ])
    model.compile(optimizer=keras.optimizers.Adam(),
                 loss=keras.losses.BinaryCrossentropy(),
                 metrics=['accuracy'])
    return model

model = cnn_mulfilter()
model.summary()

在这里插入图片描述
得到的模型结构如上图,其中192来自3种不同卷积核通过max_pooling之后相加得到的结果(64*3=192),每种卷积卷积之后得到一个向量,通过max_pooling之后得到一个特征值,每种卷积核设置filters=64所有最终一个卷积核得到64个值。

3.模型训练

deom级别测试代码

history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.1)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'valiation'], loc='upper left')
plt.show()

总结

通过CNN文本分类任务代码案例实战,学习Tensorflow2的各种api。


http://www.niftyadmin.cn/n/359382.html

相关文章

STM8、STM8S003F3P6 双机串口通信(片上串口)

背景 这里为什么要写串口通信,因为实际项目上使用了串口,STM8S003F3P6的串口简单啊,不值得一提。本文写的串口确实简单,因为这里我想先从简单的写起来,慢慢的把难的引出来。这里呢,做个提纲说明&#xff0…

ASEMI代理长电MCR100-6可控硅的性能与应用分析

编辑-Z 本文主要介绍了新型MCR100-6晶闸管的性能与应用。首先,从晶闸管的基本原理和结构出发,分析了MCR100-6晶闸管的性能特点;其次,探讨了MCR100-6晶闸管在各种电子电路中的应用;最后,对MCR100-6晶闸管的…

chatgpt赋能python:Python中的Split函数:去空操作详解

Python中的Split函数:去空操作详解 在Python编程中,我们经常需要对字符串进行操作。而字符串的分割操作在其中是非常常见的操作。Python中的split函数便是用来实现字符串分割的函数。不过,在使用split函数时通常还需要经过去除空格等操作。 …

SpringBoot Controller层传入的参数进行解密

一、 应用场景 当和第三方应用对接系统的时候, 可能别人的参数加密方式和我们的不相同,那就需要和对方沟通好他们的接口参数是如何加密的,达成一致后才方便后续的工作开展。 二、示例说明 采用Springboot 项目开发,先在compone…

打家劫舍 III——力扣337

文章目录 题目描述法一:动态规划 题目描述 法一:动态规划 问题简化:一棵二叉树,树上的每个点都有对应的权值,每个点有两种状态(选中和不选中),问在不能同时选中有父子关系的点的情况…

00_JS基础_ES6

js的标准ECMAScript(ES),现在使用的版本为ES6 js编写的位置 1.写在HTML中的scrip标签 <script>//内嵌式console.log("hello world") </script> <!--引入外部的js文件,script不能使用单标签-->2.引用中使用 <script src"../js/01_index…

vs中计算代码行数

在vs中依次点击以下几个菜单按钮&#xff1a;”编辑“&#xff0c;”查找和替换“&#xff0c;”在文件中查找“&#xff0c;然后输入如下表达式&#xff0c; b*[^:b#/].*$并点击”使用正则表达式“复选框后&#xff0c;然后再”查找范围“选项卡中选择解决方案或者工程或者本…

drf-----认证组件

认证组件 认证组件使用步骤&#xff08;固定用法&#xff09; 1 写一个类&#xff0c;继承BaseAuthentication 2 在类中写&#xff1a;authenticate 3 在方法中&#xff0c;完成登录认证&#xff0c;如果 不是登录的&#xff0c;抛异常 4 如果是登录的&#xff0c…