Mengapa akurasi validasi berfluktuasi?

31

Saya memiliki CNN empat lapis untuk memprediksi respons terhadap kanker menggunakan data MRI. Saya menggunakan aktivasi ReLU untuk memperkenalkan nonlinier. Akurasi dan kehilangan kereta secara monoton meningkat dan menurun. Tapi, akurasi pengujian saya mulai berfluktuasi liar. Saya sudah mencoba mengubah tingkat belajar, mengurangi jumlah lapisan. Tapi, itu tidak menghentikan fluktuasi. Saya bahkan membaca jawaban ini dan mencoba mengikuti arahan dalam jawaban itu, tetapi tidak berhasil lagi. Adakah yang bisa membantu saya mencari tahu di mana saya salah?

Tangkapan layar

Raghuram
sumber
Ya, saya membaca jawaban itu. Mengocok data validasi tidak membantu
Raghuram
4
Karena Anda belum membagikan cuplikan kode Anda, maka saya tidak bisa mengatakan banyak apa yang salah dalam arsitektur Anda. Tetapi dalam tangkapan layar Anda, melihat pelatihan dan akurasi validasi Anda, sangat jelas bahwa jaringan Anda terlalu cocok. Akan lebih baik jika Anda membagikan potongan kode Anda di sini.
Nain
berapa banyak sampel yang anda miliki? mungkin fluktuasi tidak terlalu signifikan. Juga, akurasinya adalah ukuran yang mengerikan
rep_ho
Dapatkah seseorang membantu saya memverifikasi jika menggunakan pendekatan ensemble bagus ketika akurasi validasinya berfluktuasi? karena saya dapat mengelola validation_accuracy berfluktuasi saya dengan ensemble ke nilai yang baik.
Sri2110

Jawaban:

27

Jika saya memahami definisi akurasi dengan benar, akurasi (% dari titik data diklasifikasikan dengan benar) kurang kumulatif daripada katakanlah MSE (mean squared error). Itu sebabnya Anda melihat bahwa Anda lossmeningkat dengan cepat, sementara akurasi berfluktuasi.

Secara intuitif, ini pada dasarnya berarti, bahwa beberapa bagian dari contoh diklasifikasikan secara acak , yang menghasilkan fluktuasi, karena jumlah tebakan acak yang benar selalu berfluktuasi (bayangkan akurasi ketika koin harus selalu mengembalikan "kepala"). Pada dasarnya sensitivitas terhadap noise (ketika klasifikasi menghasilkan hasil acak) adalah definisi umum overfitting (lihat wikipedia):

Dalam statistik dan pembelajaran mesin, salah satu tugas yang paling umum adalah mencocokkan "model" dengan satu set data pelatihan, sehingga dapat membuat prediksi yang andal pada data umum yang tidak terlatih. Dalam overfitting, model statistik menggambarkan kesalahan atau kebisingan acak alih-alih hubungan yang mendasarinya

Bukti lain dari overfitting adalah bahwa kerugian Anda meningkat, Kerugian diukur lebih tepat, lebih sensitif terhadap prediksi bising jika itu tidak tergencet oleh sigmoids / ambang batas (yang tampaknya menjadi kasus Anda untuk Kerugian itu sendiri). Secara intuitif, Anda bisa membayangkan situasi ketika jaringan terlalu yakin tentang output (ketika itu salah), sehingga memberikan nilai yang jauh dari ambang batas dalam kasus kesalahan klasifikasi acak.

Mengenai kasus Anda, model Anda tidak diatur dengan benar, kemungkinan alasan:

  • tidak cukup titik data, kapasitas terlalu banyak
  • pemesanan
  • tidak ada / salah penskalaan / normalisasi fitur
  • tingkat pembelajaran: terlalu besar, sehingga SGD melompat terlalu jauh dan merindukan area dekat minimum lokal. Ini akan menjadi kasus ekstrem "under-fitting" (ketidakpekaan terhadap data itu sendiri), tetapi mungkin menghasilkan (jenis) suara "frekuensi rendah" pada output dengan mengacak data dari input - berlawanan dengan intuisi yang overfitting, itu akan menjadi seperti selalu menebak kepala ketika memprediksi koin. Seperti yang ditunjukkan oleh @JanKukacka, tiba di area "terlalu dekat dengan" minima mungkin menyebabkan overfitting, jadi jika terlalu kecil itu akan menjadi sensitif terhadap kebisingan "frekuensi tinggi" dalam data Anda. harus berada di antara keduanya.αα ααα

Solusi yang memungkinkan:

  • mendapatkan lebih banyak data-poin (atau secara artifisial memperluas set yang sudah ada)
  • bermain dengan hiper-parameter (peningkatan / penurunan kapasitas atau istilah regularisasi misalnya)
  • regularisasi : coba putus, berhenti lebih awal, dan sebagainya
dk14
sumber
Mengenai: "Kerugian diukur lebih tepat, lebih sensitif terhadap prediksi berisik karena tidak tergencet oleh sigmoids / ambang batas", saya setuju dengan tidak ada ambang, tetapi jika Anda menggunakan mis. Entropi silang biner sebagai fungsi kerugian Anda, sigmoid masih berperan sebuah peran.
Zhubarb
1
Mengenai tingkat pembelajaran dan kehilangan bagian minimum: mencapai minimum kemungkinan besar berarti kelebihan (karena itu adalah minimum pada set pelatihan)
Jan Kukacka
@ Bosmeister benar, saya telah sedikit diucapkan ulang (lihat edit). Pemikiranku ada peningkatan Loss adalah pertanda bahwa fungsi non-squashing sedang digunakan.
dk14
@JanKukacka maksudmu minimum global? Saya menyiratkan minimum lokal (sebenarnya dekat minimum lokal) - dalam arti bahwa jika terlalu jauh dari minimum apa pun, itu akan kurang pas. Mungkin, saya harus menggambarkannya lebih hati-hati (lihat edit), terima kasih.
dk14
@ dk14 Saya menganggap minimum global tidak dapat dicapai dalam praktik, jadi maksud saya minimum lokal. Jika Anda terlalu jauh, Anda mungkin kurang pas, tetapi jika Anda terlalu dekat, kemungkinan besar Anda overfitting. Ada karya menarik oleh Moritz Hardt "Melatih lebih cepat, menyamaratakan dengan lebih baik: Stabilitas keturunan gradien stokastik" ( arxiv.org/abs/1509.01240 ) menempatkan batasan pada hubungan antara pelatihan dan kesalahan pengujian saat berlatih dengan SGD.
Jan Kukacka
6

Pertanyaan ini sudah lama tetapi memposting ini karena belum ditunjukkan:

Kemungkinan 1 : Anda menerapkan semacam preprocessing (makna nol, normalisasi, dll.) Baik untuk set pelatihan Anda atau set validasi, tetapi tidak yang lain .

Kemungkinan 2 : Jika Anda membuat beberapa layer yang berkinerja berbeda selama pelatihan dan inferensi dari awal, model Anda mungkin diimplementasikan secara tidak benar (misalnya, memindahkan rata-rata dan memindahkan standar deviasi untuk normalisasi batch yang diperbarui selama pelatihan? Jika menggunakan dropout, apakah bobot diskalakan dengan benar selama kesimpulan?). Ini mungkin terjadi jika kode Anda mengimplementasikan hal-hal ini dari awal dan tidak menggunakan fungsi built-in Tensorflow / Pytorch.

Kemungkinan 3: Overfitting, seperti yang ditunjukkan semua orang. Saya menemukan dua opsi lainnya lebih mungkin dalam situasi spesifik Anda karena akurasi validasi Anda macet di 50% dari zaman 3. Secara umum, saya akan lebih khawatir tentang overfitting jika ini terjadi pada tahap selanjutnya (kecuali Anda memiliki masalah yang sangat spesifik di tangan).

Soroush
sumber
Saya mengalami masalah yang agak mirip tetapi tidak sepenuhnya, lebih detail di sini: stackoverflow.com/questions/55348052/... Dalam kasus saya, saya benar-benar memiliki akurasi tinggi yang konsisten dengan data uji dan selama pelatihan, validasi "akurasi "(bukan kerugian) lebih tinggi dari akurasi pelatihan. Tetapi fakta bahwa itu tidak pernah konvergen dan terombang-ambing membuat saya berpikir overfitting, sementara beberapa menyarankan itu tidak terjadi, jadi saya bertanya-tanya apakah itu dan apa pembenarannya jika tidak.
dusa
1
Sejauh ini, inilah penjelasan yang paling masuk akal dari jawaban yang diberikan. Perhatikan bahwa momentum normalisasi bets tinggi (mis. 0,999, atau bahkan standar Keras 0,99) dalam kombinasi dengan tingkat pembelajaran yang tinggi juga dapat menghasilkan perilaku yang sangat berbeda dalam pelatihan dan evaluasi karena statistik lapisan tertinggal sangat jauh di belakang. Dalam hal mengurangi momentum ke sesuatu seperti 0,9 harus melakukan trik. Saya memiliki masalah yang sama dengan OP dan ini berhasil.
kristjan
5

Menambahkan ke jawaban oleh @ dk14. Jika Anda masih melihat fluktuasi setelah mengatur model Anda dengan benar , ini bisa menjadi alasan yang memungkinkan:

  • Menggunakan sampel acak dari set validasi Anda: Ini berarti set validasi Anda pada setiap langkah evaluasi berbeda, begitu juga validasi-kerugian Anda.
  • Menggunakan fungsi kerugian tertimbang (yang digunakan jika ada masalah kelas yang sangat tidak seimbang). Pada langkah kereta, Anda menimbang fungsi kerugian Anda berdasarkan bobot kelas, sedangkan pada langkah dev Anda hanya menghitung kerugian yang tidak berbobot. Dalam kasus seperti itu, meskipun jaringan Anda sedang menuju konvergensi, Anda mungkin melihat banyak fluktuasi dalam kehilangan validasi setelah setiap langkah kereta. Tetapi jika Anda menunggu gambar yang lebih besar, Anda dapat melihat bahwa jaringan Anda benar-benar konvergen ke minima dengan fluktuasi yang usang (lihat gambar terlampir untuk salah satu contohnya).masukkan deskripsi gambar di sinimasukkan deskripsi gambar di sini
bitpersecond
sumber
2

Jelas terlalu pas. Kesenjangan antara akurasi pada data pelatihan dan data tes menunjukkan bahwa Anda terlalu cocok untuk pelatihan. Mungkin regularisasi dapat membantu.

keramat
sumber
1

Akurasi validasi Anda pada masalah klasifikasi biner (saya berasumsi) adalah "berfluktuasi" sekitar 50%, itu berarti model Anda memberikan prediksi yang benar-benar acak (kadang-kadang menebak dengan benar beberapa sampel lebih banyak, kadang beberapa sampel lebih sedikit). Secara umum, model Anda tidak lebih baik daripada membalik koin.

{0;1}

Bagaimanapun, seperti yang telah ditunjukkan orang lain, model Anda mengalami overfitting yang parah. Dugaan saya adalah bahwa masalah Anda terlalu rumit , yaitu sangat sulit untuk mengekstrak informasi yang diinginkan dari data Anda, dan 4-layer conv-net terlatih sederhana seperti end2end tidak memiliki kesempatan untuk mempelajarinya .

Jan Kukacka
sumber
0

Ada beberapa cara untuk mencoba dalam situasi Anda. Pertama-tama cobalah untuk meningkatkan ukuran bets, yang membantu SGD mini-batch kurang berkeliaran dengan liar. Kedua, menyesuaikan tingkat belajar, mungkin membuatnya lebih kecil. Ketiga, coba pengoptimal yang berbeda, misalnya Adam atau RMSProp yang dapat menyesuaikan tingkat pembelajaran untuk fitur WRT. Jika mungkin, coba tambah data Anda. Terakhir, coba jaringan saraf Bayesian melalui pendekatan dropout, sebuah karya yang sangat menarik dari Yarin Gal https://arxiv.org/abs/1506.02158

pateheo
sumber
0

Sudahkah Anda mencoba jaringan yang lebih kecil? Mengingat akurasi pelatihan Anda dapat mencapai> .99, jaringan Anda tampaknya memiliki koneksi yang cukup untuk sepenuhnya memodelkan data Anda, tetapi Anda mungkin memiliki koneksi asing yang belajar secara acak (mis. Overfitting).

Dalam pengalaman saya, saya mendapatkan akurasi validasi ketidaksepakatan untuk stabil dengan jaringan yang lebih kecil dengan mencoba berbagai jaringan seperti ResNet, VGG, dan bahkan jaringan yang lebih sederhana.

teter123f
sumber