I have been working on some text classification problems lately, in which BiLSTM is used frequently, so the tf.nn.bidirectional_dynamic_rnn() function comes up a lot. I introduced the tf.nn.dynamic_rnn() function in an earlier post. Building on that, and drawing on the two blogs /wuzqChom/article/details/75453327 and /taolusi/article/details/81232210, this post explains tf.nn.bidirectional_dynamic_rnn() in detail based on my own understanding.
First, let's understand the parameters of the function
```python
bidirectional_dynamic_rnn(
    cell_fw,                # forward RNN cell
    cell_bw,                # backward RNN cell
    inputs,                 # input tensor
    sequence_length=None,   # actual lengths of the input sequences (optional; defaults to the maximum sequence length)
    initial_state_fw=None,  # initial state of the forward RNN (optional)
    initial_state_bw=None,  # initial state of the backward RNN (optional)
    dtype=None,             # data type of the initial states and outputs (optional)
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None)
```
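For context, here is a minimal sketch of constructing the two cells and calling the function. This is TensorFlow 1.x; the hidden size, sequence length, and embedding dimension are made-up values for illustration:

```python
import tensorflow as tf

hidden_num = 100  # assumed hidden size, for illustration only

# One LSTM cell per direction.
cell_fw = tf.nn.rnn_cell.LSTMCell(hidden_num)
cell_bw = tf.nn.rnn_cell.LSTMCell(hidden_num)

# Batch-major inputs: [batch_size, max_len, embeddings_num].
embedding_inputs = tf.placeholder(tf.float32, [None, 50, 128])

outputs, output_states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, embedding_inputs, time_major=False, dtype=tf.float32)
```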
It is worth noting that when the inputs tensor has shape [batch_size, max_len, embeddings_num], time_major = False; when it has shape [max_len, batch_size, embeddings_num], time_major = True. We generally feed inputs in the [batch_size, max_len, embeddings_num] format, which is why time_major defaults to False.
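If you do have batch-major data but want the time-major layout (or the other way around), a single tf.transpose swaps the two leading axes. A small sketch, with an assumed placeholder:

```python
import tensorflow as tf

# Assumed batch-major inputs: [batch_size, max_len, embeddings_num].
inputs = tf.placeholder(tf.float32, [None, 50, 128])

# Swap the batch and time axes to get the time-major layout
# [max_len, batch_size, embeddings_num] that time_major=True expects.
inputs_time_major = tf.transpose(inputs, [1, 0, 2])
```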
The return value of the function, like that of tf.nn.dynamic_rnn(), is a pair (outputs, output_states).
- outputs is (output_fw, output_bw), a tuple consisting of the forward cell's output tensor and the backward cell's output tensor. When time_major = False, output_fw and output_bw both have shape [batch_size, max_len, hiddens_num]. In this case the final outputs can be obtained with tf.concat([output_fw, output_bw], -1) or tf.concat([output_fw, output_bw], 2), and [output_fw, output_bw] here can be replaced directly with outputs. For reference, see /leviopku/article/details/82380118.
- output_states is (output_state_fw, output_state_bw), a tuple containing the last hidden states of the forward and backward directions. output_state_fw and output_state_bw are both of type LSTMStateTuple, which consists of (c, h), representing the memory cell and the hidden state respectively (unpacked in the sketch below).
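Continuing the call sketch above (same assumed names), the two LSTMStateTuples can be unpacked directly, and a common pattern for classification is to concatenate the two final hidden states:

```python
# output_states as returned by tf.nn.bidirectional_dynamic_rnn above.
output_state_fw, output_state_bw = output_states

c_fw, h_fw = output_state_fw.c, output_state_fw.h  # each [batch_size, hidden_num]
c_bw, h_bw = output_state_bw.c, output_state_bw.h

# Final hidden states of both directions, concatenated: [batch_size, 2*hidden_num].
final_state = tf.concat([h_fw, h_bw], axis=-1)
```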
The two projects I have worked on recently are BiLSTM-based text classification and Chinese entity extraction. Text classification needs the output of the last time_step, while Chinese entity extraction needs the final outputs, i.e., the output of all time_steps.
```python
# Text classification: obtain the output of the last time step as follows.
outputs, outputs_state = tf.nn.bidirectional_dynamic_rnn(
    lstm_fw_cell_m, lstm_bw_cell_m, embedding_inputs,
    time_major=False, dtype=tf.float32)

output_fw = outputs[0]
output_bw = outputs[1]  # original shape is [batch_size, max_len, hidden_num]
output_fw = tf.transpose(output_fw, [1, 0, 2])  # shape is now [max_len, batch_size, hidden_num]
output_bw = tf.transpose(output_bw, [1, 0, 2])
outputs1 = [output_fw, output_bw]
lstmoutputs = tf.concat(outputs1, 2)  # shape after concatenation is [max_len, batch_size, 2*hidden_num]
last = lstmoutputs[-1]  # output of the last time_step: [batch_size, 2*hidden_num]
```
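A note on this approach: indexing lstmoutputs[-1] only gives the true last-step outputs if every sequence is padded to max_len and no sequence_length is passed. When sequence_length is supplied, outputs past each sequence's actual length are zero, and reading the final states from output_states (as in the earlier sketch) is the safer way to get a per-sequence summary.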
```python
# Chinese entity extraction: keep the outputs of all time steps.
(output_fw_seq, output_bw_seq), _ = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=cell_fw, cell_bw=cell_bw, inputs=self.word_embeddings,
    sequence_length=self.sequence_lengths, dtype=tf.float32)

# time_major=False and inputs are [batch_size, time_step, embedding_dim],
# so concatenating on axis=-1 is equivalent to axis=2.
output = tf.concat([output_fw_seq, output_bw_seq], axis=-1)
```
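From here, entity extraction typically projects each time step's 2*hidden_num-dimensional vector down to per-tag scores. A minimal sketch continuing the snippet above, assuming a hypothetical num_tags and using TF 1.x's tf.layers.dense, which applies to the last axis of a 3-D tensor:

```python
num_tags = 7  # hypothetical tag-set size, for illustration only

# Per-time-step projection: [batch_size, time_step, 2*hidden_num]
# -> [batch_size, time_step, num_tags].
logits = tf.layers.dense(output, num_tags)
```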