Apa output dari tf.nn.dynamic_rnn ()?

8

Saya tidak yakin tentang apa yang saya mengerti dari dokumentasi resmi, yang mengatakan:

Pengembalian: Sepasang (keluaran, status) tempat:

outputs: Tensor keluaran RNN.

Jika time_major == False(default), ini akan menjadi Tensor berbentuk: [batch_size, max_time, cell.output_size].

Jika time_major == True, ini akan menjadi Tensor berbentuk: [max_time, batch_size, cell.output_size].

Catatan, jika cell.output_sizetupel integer atau objek TensorShape (mungkin bersarang), maka output akan menjadi tupel yang memiliki struktur yang sama dengan cell.output_size, berisi Tensor yang memiliki bentuk yang sesuai dengan data bentuk di cell.output_size.

state: Keadaan akhir. Jika cell.state_size adalah int, ini akan dibentuk [batch_size, cell.state_size]. Jika itu adalah TensorShape, ini akan berbentuk [batch_size] + cell.state_size. Jika itu adalah tuple int atau TensorShape (mungkin bersarang), ini akan menjadi tuple dengan bentuk yang sesuai. Jika sel adalah status LSTMCells akan menjadi tuple yang berisi LSTMStateTuple untuk setiap sel.

Apakah output[-1] selalu (dalam ketiga jenis sel yaitu RNN, GRU, LSTM) sama dengan keadaan (elemen kedua dari tuple pengembalian)? Saya kira literatur di mana-mana terlalu liberal dalam penggunaan istilah negara tersembunyi. Apakah keadaan tersembunyi di ketiga sel skor keluar (mengapa disebut tersembunyi di luar saya, akan muncul keadaan sel di LSTM harus disebut keadaan tersembunyi karena tidak terpapar)?

MiloMinderbinder
sumber

Jawaban:

10

Ya, output sel sama dengan status tersembunyi. Dalam kasus LSTM, ini adalah bagian jangka pendek dari tuple (elemen kedua LSTMStateTuple), seperti yang dapat dilihat pada gambar ini:

LSTM

Tetapi untuk tf.nn.dynamic_rnn, keadaan yang dikembalikan mungkin berbeda ketika urutan lebih pendek ( sequence_lengthargumen). Lihatlah contoh ini:

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
  [[6, 7, 8], [6, 5, 4]], # instance 2
  [[9, 0, 1], [3, 2, 1]], # instance 3
])
seq_length_batch = np.array([2, 1, 2, 2])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print(outputs_val)
  print()
  print(states_val)

Di sini kumpulan input berisi 4 urutan dan salah satunya pendek dan diisi dengan nol. Saat berlari Anda akan menemukan sesuatu seperti ini:

[[[ 0.2315362  -0.37939444 -0.625332   -0.80235624  0.2288385 ]
  [ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]]

 [[ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
  [ 0.          0.          0.          0.          0.        ]]

 [[ 0.9994331   0.9929737  -0.8311569  -0.99928087  0.9990415 ]
  [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]]

 [[ 0.9962312   0.99659646  0.98880637  0.99548346  0.9997809 ]
  [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]]

[[ 0.9999524   0.99987394  0.33580178 -0.9981791   0.99975705]
 [ 0.97374666  0.8373545  -0.7455188  -0.98751736  0.9658986 ]
 [ 0.9984355   0.9936006   0.3662448  -0.87244385  0.993848  ]
 [ 0.9915743   0.9936939   0.4348318   0.8798458   0.95265496]]

... yang memang menunjukkan itu state == output[1]untuk urutan penuh dan state == output[0]untuk yang pendek. Juga output[1]merupakan vektor nol untuk urutan ini. Hal yang sama berlaku untuk sel LSTM dan GRU.

Jadi stateadalah tensor yang nyaman yang memegang keadaan RNN aktual terakhir , mengabaikan nol. The outputtensor memegang output dari semua sel, sehingga tidak mengabaikan nol. Itulah alasan untuk mengembalikan keduanya.

Pepatah
sumber
2

Kemungkinan salinan /programming/36817596/get-last-output-of-dynamic-rnn-in-tensorflow/49705930#49705930

Pokoknya mari kita lanjutkan dengan jawabannya.

Snip kode ini dapat membantu memahami apa yang sebenarnya dikembalikan oleh dynamic_rnnlayer

=> Tuple of (keluaran, final_output_state) .

Jadi untuk input dengan panjang urutan maksimum dari T waktu langkah output adalah bentuk [Batch_size, T, num_inputs](diberikan time_major= Salah; nilai default) dan berisi status output pada setiap catatan waktu h1, h2.....hT.

Dan final_output_state memiliki bentuk [Batch_size,num_inputs]dan memiliki status sel akhir cTdan status keluaran hTdari setiap urutan bets.

Tapi karena dynamic_rnnsedang digunakan tebakan saya adalah panjang urutan Anda bervariasi untuk setiap batch.

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

Penegasan akhir akan gagal karena keadaan akhir untuk urutan ke-2 adalah pada langkah ke-6 yaitu. indeks 5 dan sisa output dari [6: 9] semua 0s dalam catatan waktu ke-2

Bhaskar Arun
sumber