当前位置: 首页 > news >正文

使用paddle进行酒店评论的情感分类5——batch准备

把原始语料中的每个句子通过截断和填充,转换成一个固定长度的句子,并将所有数据整理成mini-batch,用于训练模型,下面代码参照paddle官方


# 库文件导入
# encoding=utf8
import re
import random
import requests
import numpy as np
import paddle
from paddle.nn import Embedding
import paddle.nn.functional as F
from paddle.nn import LSTM, Embedding, Dropout, Linear
import os
import jieba
import paddle.fluidimport build_dict
import convert_corpus_to_id
import data_preprocess
import load_comment# 编写一个迭代器,每次调用这个迭代器都会返回一个新的batch,用于训练或者预测
def build_batch(word2id_dict, corpus, batch_size, epoch_num, max_seq_len, shuffle = True, drop_last = True):# 模型将会接受的两个输入:# 1. 一个形状为[batch_size, max_seq_len]的张量,sentence_batch,代表了一个mini-batch的句子。# 2. 一个形状为[batch_size, 1]的张量,sentence_label_batch,每个元素都是非0即1,代表了每个句子的情感类别(正向或者负向)sentence_batch = []sentence_label_batch = []for _ in range(epoch_num): #每个epoch前都shuffle一下数据,有助于提高模型训练的效果#但是对于预测任务,不要做数据shuffleif shuffle:random.shuffle(corpus)for sentence, sentence_label in corpus:sentence_sample = sentence[:min(max_seq_len, len(sentence))]if len(sentence_sample) < max_seq_len:for _ in range(max_seq_len - len(sentence_sample)):sentence_sample.append(word2id_dict['[pad]'])sentence_sample = [[word_id] for word_id in sentence_sample]sentence_batch.append(sentence_sample)sentence_label_batch.append([sentence_label])if len(sentence_batch) == batch_size:yield np.array(sentence_batch).astype("int64"), np.array(sentence_label_batch).astype("int64")sentence_batch = []sentence_label_batch = []if not drop_last and len(sentence_batch) > 0: # 控制样本数量不能被批次整除时的行为,若为真则丢弃最后一批样本yield np.array(sentence_batch).astype("int64"), np.array(sentence_label_batch).astype("int64")train_corpus =  load_comment.load_comment(True)
train_corpus = data_preprocess.data_preprocess(train_corpus)
word2id_freq, word2id_dict = build_dict.build_dict(train_corpus)
train_corpus = convert_corpus_to_id.convert_corpus_to_id(train_corpus, word2id_dict)for batch_id, batch in enumerate(build_batch(word2id_dict, train_corpus, batch_size=3, epoch_num=3, max_seq_len=40)): # 此处train_corpus输入的是covert_corpus_to_id之后的内容print(batch)break```
http://www.lryc.cn/news/114451.html

相关文章:

  • 04-1_Qt 5.9 C++开发指南_常用界面设计组件_字符串QString
  • Centos 从0搭建grafana和Prometheus 服务以及问题解决
  • 【代码解读】RRNet: A Hybrid Detector for Object Detection in Drone-captured Images
  • python人工智能可以干什么,python人工智能能干什么
  • K8s工作原理
  • go错误集(持续更新)
  • 【Docker】Docker中network的概要、常用命令、网络模式以及底层ip和容器映射变化的详细讲解
  • arcgis栅格数据之最佳路径分析
  • docker服务器部署Django
  • SpringBoot集成百度人脸识别实现登陆注册功能Demo(二)
  • FPGA纯verilog实现 LZMA 数据压缩,提供工程源码和技术支持
  • C++实现一个链栈
  • Vue电商项目--VUE插件的使用及原理
  • 2.部署kubernetes的组件
  • 后端开发4.Elasticsearch的搭建
  • 嵌入式该往哪个方向发展?
  • 非凸科技受邀参加中科大线上量化分享
  • Linux 命令之 - chown(改变文件拥有者及所属组)
  • 【基于openharmony的多路摄像头功能:USB设备插拔检测】
  • uni-app:实现数字文本框,以及左右加减按钮
  • 跨平台开发框架Qt:面向对象、丰富API
  • An unexpected error has occurred. Conda has prepared the above report
  • 考研C语言进阶题库——更新6-10题
  • 汽车BOOTLOADER开发经历
  • 适配器模式(C++)
  • HTTP连接之出现400 Bad Request分析
  • 后端开发, 接口幂等性是什么意思
  • k8s手动发布镜像的方法
  • 十二、ESP32控制步进电机
  • 利用openTCS实现车辆调度系统(六)openTCS订单的使用