Shape analysis of every variable in the LSTM
First, it is important to understand that iteration in an LSTM is not just a loop over minibatches: while processing each minibatch, the computation also has to advance step by step along the time dimension. Only then are the temporal dependencies captured and the state produced.
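To make this concrete, here is a minimal, self-contained sketch. The sizes batch_size=2, num_steps=5, num_hiddens=4 are made-up toy values, and the tanh update is only a stand-in for the real gate computations; the point is that within one minibatch the state H is updated once per time step.

import torch

batch_size, num_steps, num_hiddens = 2, 5, 4               # toy sizes (assumed)
X_steps = torch.randn(num_steps, batch_size, num_hiddens)  # one minibatch, time-major layout
H = torch.zeros(batch_size, num_hiddens)                   # initial state

for X in X_steps:          # advance one time step per iteration
    H = torch.tanh(X + H)  # placeholder for the real LSTM gate computations
print(H.shape)             # torch.Size([2, 4]) -- the state is carried across time steps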
Below are the weight initialization and the definition of the RNN model:
def get_lstm_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                torch.zeros(num_hiddens, device=device))

    W_xi, W_hi, b_i = three()  # Input gate parameters
    W_xf, W_hf, b_f = three()  # Forget gate parameters
    W_xo, W_ho, b_o = three()  # Output gate parameters
    W_xc, W_hc, b_c = three()  # Candidate memory cell parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,
              b_c, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
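As a quick check, we can instantiate these parameters with assumed sizes (vocab_size=28 and num_hiddens=256 are just example values, not taken from the text) and print a few shapes:

import torch

vocab_size, num_hiddens = 28, 256          # example sizes (assumed)
params = get_lstm_params(vocab_size, num_hiddens, torch.device('cpu'))
W_xi, W_hi, b_i = params[0], params[1], params[2]
print(W_xi.shape)   # torch.Size([28, 256])  -> (vocab_size, num_hiddens)
print(W_hi.shape)   # torch.Size([256, 256]) -> (num_hiddens, num_hiddens)
print(b_i.shape)    # torch.Size([256])      -> (num_hiddens,)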
class RNNModel(nn.Module):
    """The RNN model."""
    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # If the RNN is bidirectional (to be introduced later),
        # `num_directions` should be 2, else it should be 1.
        if not self.rnn.bidirectional:
            self.num_directions = 1
            self.linear = nn.Linear(self.num_hiddens, self.vocab_size)
        else:
            self.num_directions = 2
            self.linear = nn.Linear(self.num_hiddens * 2, self.vocab_size)

    def forward(self, inputs, state):
        X = F.one_hot(inputs.T.long(), self.vocab_size)
        X = X.to(torch.float32)
        Y, state = self.rnn(X, state)
        # The fully connected layer will first change the shape of `Y` to
        # (`num_steps` * `batch_size`, `num_hiddens`). Its output shape is
        # (`num_steps` * `batch_size`, `vocab_size`).
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, state

    def begin_state(self, device, batch_size=1):
        if not isinstance(self.rnn, nn.LSTM):
            # `nn.GRU` takes a tensor as hidden state
            return torch.zeros((self.num_directions * self.rnn.num_layers,
                                batch_size, self.num_hiddens), device=device)
        else:
            # `nn.LSTM` takes a tuple of hidden states
            return (torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device),
                    torch.zeros((self.num_directions * self.rnn.num_layers,
                                 batch_size, self.num_hiddens),
                                device=device))
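A hedged usage sketch of this wrapper follows; the concrete numbers 28, 256, 32, 35 for vocab_size, num_hiddens, batch_size, and num_steps are assumptions chosen only for illustration. It shows the shapes of the initial state and of the final output.

import torch
from torch import nn
import torch.nn.functional as F

vocab_size, num_hiddens = 28, 256            # assumed sizes
batch_size, num_steps = 32, 35               # assumed sizes

lstm_layer = nn.LSTM(vocab_size, num_hiddens)
net = RNNModel(lstm_layer, vocab_size)

state = net.begin_state(torch.device('cpu'), batch_size)
print(state[0].shape)   # torch.Size([1, 32, 256]) -> (num_directions*num_layers, batch_size, num_hiddens)

inputs = torch.zeros((batch_size, num_steps), dtype=torch.long)  # dummy token indices
output, state = net(inputs, state)
print(output.shape)     # torch.Size([1120, 28])   -> (num_steps*batch_size, vocab_size)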
As the first line of the forward function shows, the input is transposed and then one-hot encoded, so its shape changes from [batch_size, seq_len] to [seq_len, batch_size, vocab_size].
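This shape change can be verified directly; batch_size=2, seq_len=5, vocab_size=28 are arbitrary example values:

import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 5, 28      # example sizes (assumed)
inputs = torch.randint(0, vocab_size, (batch_size, seq_len))
X = F.one_hot(inputs.T.long(), vocab_size).to(torch.float32)
print(inputs.shape)   # torch.Size([2, 5])      -> (batch_size, seq_len)
print(X.shape)        # torch.Size([5, 2, 28])  -> (seq_len, batch_size, vocab_size)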
The input then enters the LSTM layer; its from-scratch definition is as follows:
def lstm(inputs, state, params):
    [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,
     W_hq, b_q] = params
    (H, C) = state
    outputs = []
    for X in inputs:
        I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)
        F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)
        O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)
        C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)
        C = F * C + I * C_tilda
        H = O * torch.tanh(C)  # Here we can see that C and H have the same shape
        Y = (H @ W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, C)
Note the loop for X in inputs: it iterates over the time steps, one time step per iteration, so at that point X has shape [batch_size, vocab_size].
Since W_xi has shape [vocab_size, num_hiddens], the resulting I has shape [batch_size, num_hiddens]; the same analysis applies to the other variables.
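To confirm the analysis, one can feed random one-hot inputs through the scratch lstm above and inspect the shapes; vocab_size=28, num_hiddens=256, batch_size=32, num_steps=35 are again assumed example values:

import torch
import torch.nn.functional as F

vocab_size, num_hiddens = 28, 256                # assumed sizes
batch_size, num_steps = 32, 35
device = torch.device('cpu')

params = get_lstm_params(vocab_size, num_hiddens, device)
state = (torch.zeros((batch_size, num_hiddens), device=device),   # H
         torch.zeros((batch_size, num_hiddens), device=device))   # C

tokens = torch.randint(0, vocab_size, (batch_size, num_steps))
inputs = F.one_hot(tokens.T, vocab_size).to(torch.float32)  # (num_steps, batch_size, vocab_size)

outputs, (H, C) = lstm(inputs, state, params)
print(H.shape)        # torch.Size([32, 256])   -> (batch_size, num_hiddens), same as C
print(outputs.shape)  # torch.Size([1120, 28])  -> (num_steps*batch_size, vocab_size)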