Apa sebenarnya mekanisme perhatian?

23

Mekanisme perhatian telah digunakan dalam berbagai makalah Deep Learning dalam beberapa tahun terakhir. Ilya Sutskever, kepala penelitian di Open AI, dengan antusias memuji mereka: https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Eugenio Culurciello di Purdue University telah mengklaim bahwa RNN dan LSTM harus ditinggalkan demi jaringan saraf murni berdasarkan perhatian:

https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

Ini tampaknya berlebihan, tetapi tidak dapat dipungkiri bahwa model murni berbasis perhatian telah melakukan cukup baik dalam tugas pemodelan urutan: kita semua tahu tentang kertas yang dinamai tepat dari Google, Perhatian adalah semua yang Anda butuhkan

Namun, apa sebenarnya model berbasis perhatian? Saya belum menemukan penjelasan yang jelas tentang model-model tersebut. Misalkan saya ingin memperkirakan nilai baru dari rangkaian waktu multivarian, mengingat nilai historisnya. Cukup jelas bagaimana cara melakukannya dengan RNN yang memiliki sel LSTM. Bagaimana saya melakukan hal yang sama dengan model berbasis perhatian?

DeltaIV
sumber

Jawaban:

20

Perhatian adalah metode untuk menggabungkan satu set vektor menjadi hanya satu vektor, seringkali melalui vektor lookup . Biasanya, adalah input ke model atau status tersembunyi dari langkah-langkah sebelumnya, atau status tersembunyi satu tingkat ke bawah (dalam kasus tumpukan LSTM).viuvi

Hasilnya sering disebut vektor konteks , karena berisi konteks yang relevan dengan langkah waktu saat ini.c

Vektor konteks tambahan ini c kemudian dimasukkan ke dalam RNN / LSTM juga (dapat disatukan dengan input asli). Oleh karena itu, konteksnya dapat digunakan untuk membantu prediksi.

p=softmax(VTu)c=ipiviVviuht

viTuf(vi,u)f

p=softmax(qTtanh(W1vi+W2ht))vhtqW

Beberapa makalah yang memamerkan berbagai variasi pada gagasan perhatian:

Pointer Networks menggunakan perhatian pada input referensi untuk menyelesaikan masalah optimisasi kombinatorial.

Jaringan Entitas Berulang mempertahankan status memori terpisah untuk entitas yang berbeda (orang / objek) saat membaca teks, dan memperbarui status memori yang benar menggunakan perhatian.

kipvi


Berikut ini adalah implementasi cepat dari satu bentuk perhatian, meskipun saya tidak dapat menjamin kebenaran di luar kenyataan bahwa itu melewati beberapa tes sederhana.

RNN dasar:

def rnn(inputs_split):
    bias = tf.get_variable('bias', shape = [hidden_dim, 1])
    weight_hidden = tf.tile(tf.get_variable('hidden', shape = [1, hidden_dim, hidden_dim]), [batch, 1, 1])
    weight_input = tf.tile(tf.get_variable('input', shape = [1, hidden_dim, in_dim]), [batch, 1, 1])

    hidden_states = [tf.zeros((batch, hidden_dim, 1), tf.float32)]
    for i, input in enumerate(inputs_split):
        input = tf.reshape(input, (batch, in_dim, 1))
        last_state = hidden_states[-1]
        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )
        hidden_states.append(hidden)
    return hidden_states[-1]

Dengan perhatian, kami menambahkan hanya beberapa baris sebelum keadaan tersembunyi yang baru dihitung:

        if len(hidden_states) > 1:
            logits = tf.transpose(tf.reduce_mean(last_state * hidden_states[:-1], axis = [2, 3]))
            probs = tf.nn.softmax(logits)
            probs = tf.reshape(probs, (batch, -1, 1, 1))
            context = tf.add_n([v * prob for (v, prob) in zip(hidden_states[:-1], tf.unstack(probs, axis = 1))])
        else:
            context = tf.zeros_like(last_state)

        last_state = tf.concat([last_state, context], axis = 1)

        hidden = tf.nn.tanh( tf.matmul(weight_input, input) + tf.matmul(weight_hidden, last_state) + bias )

kode lengkap

shimao
sumber
p=softmax(VTu)ic=ipivipiVTvVTv
1
zi=viTup=softmax(z)pi=eizjejz
ppi
1
ya, itulah yang saya maksud
shimao
@shimao Saya membuat ruang obrolan , beri tahu saya jika Anda tertarik untuk berbicara (bukan tentang pertanyaan ini)
DeltaIV