RNNs: Kapan menerapkan BPTT dan / atau memperbarui bobot?

15

Saya mencoba memahami aplikasi tingkat tinggi RNNs untuk pelabelan urutan melalui (antara lain) makalah Graves 2005 tentang klasifikasi fonem.

Untuk meringkas masalah: Kami memiliki satu set pelatihan besar yang terdiri dari (input) file audio dari kalimat tunggal dan (output) waktu mulai berlabel ahli, waktu berhenti dan label untuk fonem individu (termasuk beberapa fonem "istimewa" seperti diam, sedemikian sehingga setiap sampel dalam setiap file audio diberi label dengan beberapa simbol fonem.)

Tujuan makalah ini adalah untuk menerapkan RNN dengan sel memori LSTM di lapisan tersembunyi untuk masalah ini. (Dia menerapkan beberapa varian dan beberapa teknik lainnya sebagai pembanding. Saya untuk saat ini HANYA tertarik pada LSTM searah, untuk menjaga hal-hal sederhana.)

Saya percaya saya memahami arsitektur jaringan: Lapisan input yang sesuai dengan 10 ms windows dari file audio, diproses dengan cara standar untuk pekerjaan audio; lapisan tersembunyi sel LSTM, dan lapisan keluaran dengan pengkodean satu-panas dari semua kemungkinan 61 simbol telepon.

Saya percaya saya memahami persamaan (rumit tapi langsung) dari forward pass dan backward pass melalui unit LSTM. Mereka hanyalah kalkulus dan aturan rantai.

Yang tidak saya mengerti, setelah membaca makalah ini dan beberapa yang serupa beberapa kali, adalah kapan tepatnya menerapkan algoritma backpropagation dan kapan tepatnya memperbarui berbagai bobot dalam neuron.

Ada dua metode yang masuk akal:

1) Backprop frame-bijaksana dan pembaruan

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
- Apply backpropagation to this frame's error
- Update weights accordingly
At end of sentence, reset memory
load another sentence and continue.

atau,

2) Backprop bijaksana dan memperbarui:

Load a sentence.  
Divide into frames/timesteps.  
For each frame:
- Apply forward step
- Determine error function
At end of sentence:
- Apply backprop to average of sentence error function
- Update weights accordingly
- Reset memory
Load another sentence and continue.

Perhatikan bahwa ini adalah pertanyaan umum tentang pelatihan RNN menggunakan kertas Graves sebagai contoh yang runcing (dan relevan secara pribadi): Ketika melatih RNN tentang urutan, apakah backprop diterapkan di setiap langkah waktu? Apakah bobot disesuaikan setiap waktu? Atau, dalam analogi longgar dengan pelatihan batch tentang arsitektur umpan-maju secara ketat, apakah kesalahan diakumulasi dan dirata-ratakan dalam urutan tertentu sebelum backprop dan pembaruan berat diterapkan?

Atau apakah saya lebih bingung daripada yang saya kira?

Novak
sumber

Jawaban:

25

Saya akan berasumsi kita sedang berbicara tentang jaring saraf berulang (RNNs) yang menghasilkan output pada setiap langkah waktu (jika output hanya tersedia di akhir urutan, itu masuk akal untuk menjalankan backprop di akhir). RNNs dalam pengaturan ini sering dilatih menggunakan backpropagation terpotong melalui waktu (BPTT) terpotong, beroperasi secara berurutan pada 'potongan' dari suatu urutan. Prosedurnya terlihat seperti ini:

  1. Maju terus: langkah-langkah waktu berikutnya , menghitung status input, tersembunyi, dan keluaran.k1
  2. Hitung kerugiannya, simpulkan dari langkah waktu sebelumnya (lihat di bawah).
  3. Backward pass: Hitung gradien dari kehilangan wrt semua parameter, terakumulasi selama langkah waktu sebelumnya (ini mengharuskan menyimpan semua aktivasi untuk langkah-langkah waktu ini). Klip gradien untuk menghindari masalah gradien meledak (jarang terjadi).k2
  4. Perbarui parameter (ini terjadi satu kali per potong, tidak secara bertahap pada setiap langkah waktu).
  5. Jika memproses beberapa bongkahan dari urutan yang lebih lama, simpan keadaan tersembunyi pada langkah terakhir kali (akan digunakan untuk menginisialisasi keadaan tersembunyi untuk memulai bongkahan berikutnya). Jika kita telah mencapai akhir urutan, atur ulang memori / keadaan tersembunyi dan pindah ke awal urutan berikutnya (atau awal urutan yang sama, jika hanya ada satu).
  6. Ulangi dari langkah 1.

Bagaimana kerugian dijumlahkan tergantung pada dan . Misalnya, ketika , kerugian dijumlahkan di atas langkah waktu lalu , tetapi prosedurnya berbeda ketika (lihat Williams dan Peng 1990).k1k2k1=k2k1=k2k2>k1

Komputasi dan pembaruan Gradient dilakukan setiap langkah waktu karena lebih murah secara komputasi daripada memperbarui di setiap langkah waktu. Memutakhirkan beberapa kali per urutan (yaitu menetapkan kurang dari panjang urutan) dapat mempercepat pelatihan karena pembaruan berat lebih sering.k1k1

Backpropagation dilakukan hanya untuk langkah waktu karena secara komputasi lebih murah daripada merambat kembali ke awal urutan (yang akan membutuhkan penyimpanan dan berulang kali memproses semua langkah waktu). Gradien yang dihitung dengan cara ini merupakan perkiraan gradien 'benar' yang dihitung pada semua langkah waktu. Tetapi, karena masalah gradien menghilang, gradien akan cenderung mendekati nol setelah beberapa langkah waktu; menyebarkan melampaui batas ini tidak akan memberikan manfaat apa pun. Menyetel terlalu pendek dapat membatasi skala temporal tempat jaringan dapat belajar. Namun, memori jaringan tidak terbatas pada langkah waktu karena unit tersembunyi dapat menyimpan informasi di luar periode ini (misk2k2k2).

Selain pertimbangan komputasi, pengaturan yang tepat untuk dan bergantung pada statistik data (misalnya skala temporal struktur yang relevan untuk menghasilkan keluaran yang baik). Mereka mungkin juga tergantung pada detail jaringan. Misalnya, ada sejumlah arsitektur, trik inisialisasi, dll. Yang dirancang untuk mengurangi masalah gradien yang membusuk.k1k2

Pilihan Anda 1 ('backprop frame-wise') sesuai dengan pengaturan ke dan dengan jumlah langkah waktu dari awal kalimat ke titik saat ini. Opsi 2 ('backprops bijaksana-bijaksana') sesuai dengan pengaturan dan dengan panjang kalimat. Keduanya adalah pendekatan yang valid (dengan pertimbangan komputasi / kinerja seperti di atas; # 1 akan cukup intensif secara komputasi untuk urutan yang lebih lama). Tak satu pun dari pendekatan ini akan disebut 'terpotong' karena backpropagation terjadi di seluruh urutan. Pengaturan lain dari dan dimungkinkan; Saya akan mencantumkan beberapa contoh di bawah ini.k11k2k1k2k1k2

Referensi yang menggambarkan BPTT terpotong (prosedur, motivasi, masalah praktis):

  • Sutskever (2013) . Pelatihan jaringan saraf berulang.
  • Mikolov (2012) . Model Bahasa Statistik Berdasarkan Jaringan Saraf Tiruan.
    • Menggunakan vanilla RNNs untuk memproses data teks sebagai urutan kata, ia merekomendasikan pengaturan hingga 10-20 kata dan hingga 5 katak1k2
    • Melakukan beberapa pembaruan per urutan (mis. kurang dari panjang urutan) berfungsi lebih baik daripada memperbarui di akhir urutank1
    • Melakukan pembaruan sekali per potong lebih baik daripada secara bertahap (yang mungkin tidak stabil)
  • Williams dan Peng (1990) . Algoritme berbasis gradien yang efisien untuk pelatihan on-line lintasan jaringan berulang.
    • Proposal algoritma asli (?)
    • Mereka mendiskusikan pilihan dan (yang mereka sebut dan ). Mereka hanya mempertimbangkan .k1k2hhk2k1
    • Catatan: Mereka menggunakan frasa "BPTT (h; h ')" atau' algoritma yang ditingkatkan 'untuk merujuk pada apa yang disebut referensi lain' BPTT terpotong '. Mereka menggunakan frasa 'BPT terpotong' berarti kasus khusus di mana .k1=1

Contoh lain menggunakan BPTT terpotong:

  • (Karpathy 2015). char-rnn.
    • Deskripsi dan kode
    • Vanilla RNN memproses dokumen teks satu karakter pada satu waktu. Terlatih untuk memprediksi karakter selanjutnya. karakter. Jaringan yang digunakan untuk menghasilkan teks baru dengan gaya dokumen pelatihan, dengan hasil yang lucu.k1=k2=25
  • Graves (2014) . Menghasilkan urutan dengan jaringan saraf berulang.
    • Lihat bagian tentang membuat artikel Wikipedia yang disimulasikan. Jaringan LSTM memproses data teks sebagai urutan byte. Terlatih untuk memprediksi byte berikutnya. byte. Memori LSTM diatur ulang setiap byte.k1=k2=10010,000
  • Sak et al. (2014) . Arsitektur jaringan saraf berulang jangka panjang berbasis memori jangka panjang untuk pengenalan ucapan kosakata yang besar.
    • Jaringan LSTM yang dimodifikasi, memproses urutan fitur akustik. .k1=k2=20
  • Ollivier et al. (2015) . Pelatihan jaringan berulang online tanpa mundur.
    • Poin dari makalah ini adalah untuk mengusulkan algoritma pembelajaran yang berbeda, tetapi mereka membandingkannya dengan BPTT terpotong. Menggunakan vanilla RNNs untuk memprediksi urutan simbol. Hanya menyebutkan di sini untuk mengatakan bahwa mereka menggunakan .k1=k2=15
  • Hochreiter dan Schmidhuber (1997) . Memori jangka pendek yang panjang.
    • Mereka menggambarkan prosedur yang dimodifikasi untuk LSTM
pengguna20160
sumber
Ini adalah jawaban yang luar biasa, dan saya berharap saya memiliki posisi di forum ini untuk memberikan hadiah besar untuk itu. Yang sangat berguna adalah diskusi konkret k1 vs k2 untuk mengontekstualisasikan dua kasus saya terhadap penggunaan yang lebih umum, dan contoh numerik yang sama.
Novak