MATLAB文本分析:13:使用深度学习对文本数据进行分类 您所在的位置:网站首页 matlab数据类型转换 MATLAB文本分析:13:使用深度学习对文本数据进行分类

MATLAB文本分析:13:使用深度学习对文本数据进行分类

2023-03-24 06:50| 来源: 网络整理| 查看: 265

此示例说明如何使用深度学习长短期记忆 (LSTM) 网络对文本数据进行分类。

文本数据自然是顺序的。一段文本是一系列单词,它们之间可能具有依赖关系。要学习和使用长期依赖关系对序列数据进行分类,请使用 LSTM 神经网络。LSTM 网络是一种循环神经网络 (RNN),可以学习序列数据时间步长之间的长期依赖关系。

要将文本输入到 LSTM 网络,首先要将文本数据转换为数字序列。您可以使用将文档映射到数字索引序列的单词编码来实现此目的。为了获得更好的结果,还在网络中包含一个词嵌入层。词嵌入将词汇表中的词映射到数字向量而不是标量索引。这些嵌入捕获了单词的语义细节,因此具有相似含义的单词具有相似的向量。词嵌入还通过矢量算法对单词之间的关系进行建模。例如,“罗马之于意大利,巴黎之于法国”的关系可以用等式意大利 –罗马 + 巴黎 = 法国来描述。

本例中训练和使用LSTM网络有四个步骤:

导入和预处理数据。使用单词编码将单词转换为数字序列。使用词嵌入层创建和训练 LSTM 网络。使用经过训练的 LSTM 网络对新文本数据进行分类。导入数据

导入工厂报告数据。此数据包含工厂事件的分词文本描述。要将文本数据作为字符串导入,请将文本类型指定为'string'.

filename = "factoryReports.csv"; data = readtable(filename,'TextType','string'); head(data) ans=8×5 table Description Category Urgency Resolution Cost _____________________________________________________________________ ____________________ ________ ____________________ _____ ​ "Items are occasionally getting stuck in the scanner spools." "Mechanical Failure" "Medium" "Readjust Machine" 45 "Loud rattling and banging sounds are coming from assembler pistons." "Mechanical Failure" "Medium" "Readjust Machine" 35 "There are cuts to the power when starting the plant." "Electronic Failure" "High" "Full Replacement" 16200 "Fried capacitors in the assembler." "Electronic Failure" "High" "Replace Components" 352 "Mixer tripped the fuses." "Electronic Failure" "Low" "Add to Watch List" 55 "Burst pipe in the constructing agent is spraying coolant." "Leak" "High" "Replace Components" 371 "A fuse is blown in the mixer." "Electronic Failure" "Low" "Replace Components" 441 "Things continue to tumble off of the belt." "Mechanical Failure" "Low" "Readjust Machine" 38

此示例的目标是按Category列中的标签对事件进行分类。要将数据分类,请将这些标签转换为分类标签。

data.Category = categorical(data.Category);

使用直方图查看数据中类的分布。

figure histogram(data.Category); xlabel("Class") ylabel("Frequency") title("Class Distribution")

下一步是将其划分为用于训练和验证的集合。将数据划分区为训练集和留出集以进行验证和测试。将留出百分比指定为 20%。

cvp = cvpartition(data.Category,'Holdout',0.2); dataTrain = data(training(cvp),:); dataValidation = data(test(cvp),:);

从分区表中提取文本数据和标签。

textDataTrain = dataTrain.Description; textDataValidation = dataValidation.Description; YTrain = dataTrain.Category; YValidation = dataValidation.Category;

要检查您是否已正确导入数据,请使用词云可视化训练文本数据。

figure wordcloud(textDataTrain); title("Training Data")

预处理文本数据

创建一个对文本数据进行分词和预处理的函数。示例末尾列出的函数preprocessText执行以下步骤:

使用tokenizedDocument切分文本。使用lower将文本转换为小写。使用erasePunctuation删除标点符号。

使用函数preprocessText,预处理训练数据和验证数据。

documentsTrain = preprocessText(textDataTrain); documentsValidation = preprocessText(textDataValidation);

查看前几个预处理的训练文档。

documentsTrain(1:5) ans = 5×1 tokenizedDocument: ​ 9 tokens: items are occasionally getting stuck in the scanner spools 10 tokens: loud rattling and banging sounds are coming from assembler pistons 10 tokens: there are cuts to the power when starting the plant 5 tokens: fried capacitors in the assembler 4 tokens: mixer tripped the fuses将文档转换为序列

要将文档输入 LSTM 网络,请使用单词编码将文档转换为数字索引序列。

使用函数wordEncoding创建单词编码。

enc = wordEncoding(documentsTrain);

下一个转换步骤是填充和截断文档,使它们的长度都相同。函数trainingOptions提供了自动填充和截断输入序列的选项。然而,这些选项不太适合词向量序列。相反,需要手动填充和截断序列。如果您左填充并截断词向量序列,那么训练可能会有所改善。

填充和截断文档,首先选择一个目标长度,然后截断比它长的文档和左填充比它短的文档。为获得最佳结果,目标长度应较短且不会丢弃大量数据。要找到合适的目标长度,请查看训练文档长度的直方图。

documentLengths = doclength(documentsTrain); figure histogram(documentLengths) title("Document Lengths") xlabel("Length") ylabel("Number of Documents")

大多数训练文档的分词数少于 10 个,可以把这个数值作为截断和填充的目标长度。

使用doc2sequence将文档转换为数字索引序列。要截断或向左填充序列以使其长度为 10,请将选项'Length'设置为 10。

sequenceLength = 10; XTrain = doc2sequence(enc,documentsTrain,'Length',sequenceLength); XTrain(1:5) ans=5×1 cell array {1×10 double} {1×10 double} {1×10 double} {1×10 double} {1×10 double}

使用相同的选项将验证文档转换为序列。

XValidation = doc2sequence(enc,documentsValidation,'Length',sequenceLength);创建和训练 LSTM 网络

定义 LSTM 网络架构。要将序列数据输入网络,包括一个序列输入层并将输入大小设置为 1。接下来,包含一个维度为 50 的词嵌入层和与单词编码相同数量的词。接下来,包含一个 LSTM 层并将隐藏单元数设置为 80。要将 LSTM 层用于序列-标签分类问题,请将输出模式设置为'last'。最后,添加一个与类数大小相同的全连接层、一个softmax层和一个分类层。

inputSize = 1; embeddingDimension = 50; numHiddenUnits = 80; ​ numWords = enc.NumWords; numClasses = numel(categories(YTrain)); ​ layers = [ ... sequenceInputLayer(inputSize) wordEmbeddingLayer(embeddingDimension,numWords) lstmLayer(numHiddenUnits,'OutputMode','last') fullyConnectedLayer(numClasses) softmaxLayer classificationLayer] layers = 6x1 Layer array with layers: ​ 1 '' Sequence Input Sequence input with 1 dimensions 2 '' Word Embedding Layer Word embedding layer with 50 dimensions and 423 unique words 3 '' LSTM LSTM with 80 hidden units 4 '' Fully Connected 4 fully connected layer 5 '' Softmax softmax 6 '' Classification Output crossentropyex指定训练选项

指定训练选项:

使用 Adam 求解器进行训练。指定小批量mini-batch大小为 16。每个epoch都重新打乱数据。通过将选项'Plots'设置为'training-progress'来监控训练进度。使用选项'ValidationData'指定验证数据。通过将选项'Verbose'设置为false来抑制详细输出。

默认情况下,trainNetwork要使用到 GPU;如果没有GPU,就使用 CPU。要手动指定执行环境,请使用 trainingOptions的'ExecutionEnvironment'名称-值对参数。在 CPU 上训练比在 GPU 上训练花费的时间要长得多。使用 GPU 进行训练需要 Parallel Computing Toolbox™ 和受支持的 GPU 设备。

options = trainingOptions('adam', ... 'MiniBatchSize',16, ... 'GradientThreshold',2, ... 'Shuffle','every-epoch', ... 'ValidationData',{XValidation,YValidation}, ... 'Plots','training-progress', ... 'Verbose',false);

使用trainNetwork函数训练 LSTM 网络。

net = trainNetwork(XTrain,YTrain,layers,options);

使用新数据进行预测

对三个新报告的事件类型进行分类。创建一个包含新报告的字符串数组。

reportsNew = [ ... "Coolant is pooling underneath sorter." "Sorter blows fuses at start up." "There are some very loud rattling sounds coming from the assembler."];

使用预处理步骤,将预处理的文本数据用作训练文档。

documentsNew = preprocessText(reportsNew);

使用doc2sequence,设置与创建训练序列时相同的选项,将文本数据转换为序列。

XNew = doc2sequence(enc,documentsNew,'Length',sequenceLength);

使用经过训练的 LSTM 网络对新序列进行分类。

labelsNew = classify(net,XNew) labelsNew = 3×1 categorical Leak Electronic Failure Mechanical Failure 预处理函数

函数preprocessText执行以下步骤:

使用tokenizedDocument标记文本。使用lower将文本转换为小写。使用erasePunctuation删除标点符号。function documents = preprocessText(textData) ​ % Tokenize the text. documents = tokenizedDocument(textData); ​ % Convert to lowercase. documents = lower(documents); ​ % Erase punctuation. documents = erasePunctuation(documents); ​ end

注:本文根据MATLAB官网内容修改而成。

关注公众号DataXY,获取免费MATLAB视频课程。欢迎您进一步了解以下MATLAB系列文章:



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有