RNN belajar gelombang sinus dari frekuensi yang berbeda

8

Sebagai pemanasan dengan jaringan saraf berulang, saya mencoba memprediksi gelombang sinus dari gelombang sinus lain dari frekuensi lain.

Model saya adalah RNN sederhana, forward pass-nya dapat dinyatakan sebagai berikut:

rt=σ(Winxt+Wrecrt1))zt=Woutrt
di mana adalah fungsi sigmoïd.σ

Ketika kedua input input dan output yang diharapkan adalah dua gelombang sinus dari frekuensi yang sama tetapi dengan (mungkin) pergeseran fasa, model ini mampu melakukan konvergensi yang tepat dengan pendekatan yang masuk akal.

Namun, dalam kasus berikut, model konvergen ke minimum lokal dan memprediksi nol setiap saat:

  • input:x=sin(t)
  • hasil yang diharapkan:y=sin(t2)

Inilah yang diprediksi jaringan ketika diberi urutan input penuh setelah 10 zaman pelatihan, menggunakan batch mini ukuran 16, tingkat pembelajaran 0,01, panjang urutan 16 dan lapisan tersembunyi ukuran 32:

Prediksi jaringan setelah 10 zaman, menggunakan mini-batch ukuran 16

Yang membuat saya berpikir jaringan tidak dapat belajar melalui waktu dan hanya mengandalkan input saat ini untuk membuat prediksi.

Saya mencoba untuk menyetel tingkat belajar, panjang urutan dan ukuran lapisan tersembunyi tanpa banyak keberhasilan.

Saya mengalami masalah yang sama persis dengan LSTM. Saya tidak ingin percaya bahwa arsitektur ini cacat, ada petunjuk tentang apa yang saya lakukan salah?

Saya menggunakan paket rnn untuk Torch, kodenya ada di Gist .

Simon
sumber

Jawaban:

3

Data Anda pada dasarnya tidak dapat dipelajari dengan RNN yang dilatih dengan cara itu. Masukan Anda adalah adalah periodicsin(t)2πsin(t)=sin(t+2π)

tetapi target Anda adalah periodik dansin(t/2)4πsin(t/2)=sin(t+2π)

Oleh karena itu, dalam dataset Anda, Anda akan memiliki pasangan input yang identik dengan output yang berlawanan. Dalam hal Mean Squared Error, itu berarti bahwa solusi optimal adalah fungsi nol.

Ini adalah dua irisan plot Anda di mana Anda dapat melihat input yang identik tetapi berlawanan sasaran masukkan deskripsi gambar di sini

ChenM
sumber
1
Untuk menguraikan jawaban ini, masalah muncul karena menggunakan inisialisasi umpan balik yang sama untuk input yang berbeda. Saya memecahkan ini dengan melakukan (secara acak) lebih maju daripada mundur untuk mempelajari urutan lengkap.
Simon