使用Paddle进行酒店评论的情感分类（5）——batch准备
把原始语料中的每个句子通过截断和填充转换成固定长度的句子，并将所有数据整理成mini-batch，用于模型训练。下面的代码参照了Paddle官方教程。
# Library imports
# encoding=utf8
import os
import random
import re

import jieba
import numpy as np
import paddle
# NOTE(review): paddle.fluid is deprecated and appears to be gone from newer
# Paddle releases; nothing visible in this file uses it — confirm and drop.
import paddle.fluid
import paddle.nn.functional as F
import requests
from paddle.nn import Embedding
from paddle.nn import LSTM, Embedding, Dropout, Linear

import build_dict
import convert_corpus_to_id
import data_preprocess
import load_comment

# Write an iterator: each call to it returns a new batch, for training or prediction.
def _to_int64_arrays(sentence_batch, sentence_label_batch):
    """Convert one collected batch into the pair of int64 numpy arrays the model consumes."""
    return (
        np.array(sentence_batch).astype("int64"),
        np.array(sentence_label_batch).astype("int64"),
    )


def build_batch(word2id_dict, corpus, batch_size, epoch_num, max_seq_len, shuffle=True, drop_last=True):
    """Yield fixed-length mini-batches of (sentence ids, labels) over several epochs.

    This is a generator; every iteration produces a new batch for training or
    prediction.

    Args:
        word2id_dict: vocabulary mapping. Only its ``'[pad]'`` entry is read,
            and only when a sentence is shorter than ``max_seq_len``.
        corpus: iterable of ``(sentence, label)`` pairs, where ``sentence`` is a
            sliceable sequence of word ids (the output of
            ``convert_corpus_to_id``). It is copied once up front, so the
            caller's corpus is never reordered and one-shot iterables work for
            every epoch.
        batch_size: number of samples per yielded batch; must be >= 1.
        epoch_num: number of passes over the corpus.
        max_seq_len: each sentence is truncated, or right-padded with the
            ``'[pad]'`` id, to exactly this many tokens.
        shuffle: reshuffle the samples before every epoch. This helps training;
            pass False for prediction so output order follows input order.
        drop_last: if True, samples left over after the last full batch are
            discarded; if False they are yielded as a final, smaller batch.

    Yields:
        A pair of int64 numpy arrays:
        1. ``sentence_batch`` with shape ``[batch_size, max_seq_len, 1]``. Each
           word id is wrapped in its own length-1 list.
        2. ``sentence_label_batch`` with shape ``[batch_size, 1]``. It holds
           each sentence's label (0 or 1, i.e. negative/positive sentiment, per
           the original author's notes).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1. The error is raised
            lazily, on the first ``next()``, because this is a generator.

    Note:
        A partially filled batch is carried into the next epoch rather than
        flushed at each epoch boundary, so one batch may mix samples from two
        consecutive epochs.
    """
    if batch_size < 1:
        # Otherwise no full batch is ever emitted and the generator silently
        # yields nothing (or one huge batch when drop_last is False).
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    # Shuffle a private copy instead of the caller's list. Shuffling the copy
    # consumes the same random numbers, so seeded runs give the same batches.
    samples = list(corpus)

    sentence_batch = []
    sentence_label_batch = []

    for _ in range(epoch_num):
        # Do not shuffle for prediction tasks (shuffle=False).
        if shuffle:
            random.shuffle(samples)

        for sentence, sentence_label in samples:
            # Slicing already clamps to len(sentence); list() gives a private,
            # appendable copy even for tuple/array inputs.
            sentence_sample = list(sentence[:max_seq_len])
            shortfall = max_seq_len - len(sentence_sample)
            if shortfall > 0:
                sentence_sample.extend([word2id_dict['[pad]']] * shortfall)

            # Wrap every id so the sentence tensor gets a trailing axis of size 1.
            sentence_batch.append([[word_id] for word_id in sentence_sample])
            sentence_label_batch.append([sentence_label])

            if len(sentence_batch) == batch_size:
                yield _to_int64_arrays(sentence_batch, sentence_label_batch)
                sentence_batch = []
                sentence_label_batch = []

    # Handle a sample count that is not a multiple of batch_size.
    if not drop_last and sentence_batch:
        yield _to_int64_arrays(sentence_batch, sentence_label_batch)


if __name__ == "__main__":
    # Demo: build the id-encoded training corpus with the helper modules
    # imported at the top of this file, print the first batch, and stop.
    train_corpus = load_comment.load_comment(True)
    train_corpus = data_preprocess.data_preprocess(train_corpus)
    word2id_freq, word2id_dict = build_dict.build_dict(train_corpus)
    # build_batch expects the output of convert_corpus_to_id (word ids, not words).
    train_corpus = convert_corpus_to_id.convert_corpus_to_id(train_corpus, word2id_dict)
    for batch in build_batch(word2id_dict, train_corpus, batch_size=3, epoch_num=3, max_seq_len=40):
        print(batch)
        break