Apakah praktik umum untuk meminimalkan kerugian rata-rata dari batch bukan jumlah?

15

Tensorflow memiliki contoh tutorial tentang mengklasifikasikan CIFAR-10 . Pada tutorial, rata-rata kehilangan entropi silang di seluruh batch diminimalkan.

def loss(logits, labels):
  """Add L2Loss to all the trainable variables.
  Add summary for for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]
  Returns:
    Loss tensor of type float.
  """
  # Calculate the average cross entropy loss across the batch.
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits, labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')

Lihat cifar10.py , baris 267.

Mengapa itu tidak meminimalkan jumlah lintas batch? Apakah itu membuat perbedaan? Saya tidak mengerti bagaimana ini akan mempengaruhi perhitungan backprop.

Bentrokan
sumber
Tidak persis jumlah / avg terkait, tetapi pilihan kerugian adalah pilihan desain aplikasi. Misalnya, jika Anda mahir dalam hal rata-rata, optimalkan rata-rata. Jika aplikasi Anda peka terhadap skenario terburuk (mis., Kecelakaan otomotif), Anda harus mengoptimalkan nilai maks.
Alex Kreimer
Lihat juga: stats.stackexchange.com/questions/358786/…
Sycorax berkata Reinstate Monica

Jawaban:

15

Seperti disebutkan oleh pkubik, biasanya ada istilah regularisasi untuk parameter yang tidak tergantung pada input, misalnya dalam tensorflow seperti

# Loss function using L2 Regularization
regularizer = tf.nn.l2_loss(weights)
loss = tf.reduce_mean(loss + beta * regularizer)

Dalam hal ini, rata-rata selama batch mini membantu menjaga rasio tetap antara cross_entropykerugian dan regularizerkerugian saat ukuran batch diubah.

Selain itu tingkat pembelajaran juga peka terhadap besarnya kerugian (gradien), sehingga untuk menormalkan hasil ukuran batch yang berbeda, mengambil rata-rata tampaknya merupakan pilihan yang lebih baik.


Memperbarui

Makalah ini oleh Facebook (Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour) menunjukkan bahwa, sebenarnya penskalaan tingkat pembelajaran menurut ukuran bets bekerja dengan sangat baik:

Linear Scaling Rule: Ketika ukuran minibatch dikalikan dengan k, kalikan laju pembelajaran dengan k.

yang pada dasarnya sama dengan mengalikan gradien dengan k dan mempertahankan tingkat pembelajaran tidak berubah, jadi saya kira mengambil rata-rata tidak perlu.

dontloo
sumber
8

Saya akan fokus pada bagian:

Saya tidak mengerti bagaimana ini akan mempengaruhi perhitungan backprop.

1BLSUM=BLAVGBdLSUMdx=BdLAVGdx

dL.dx=limΔ0L.(x+Δ)-L.(x)Δ
d(cL.)dx=limΔ0cL.(x+Δ)-cL.(x)Δ
d(cL.)dx=climΔ0L.(x+Δ)-L.(x)Δ=cdL.dx

Dalam SGD kita akan memperbarui bobot menggunakan gradiennya dikalikan dengan laju pembelajaran dan kita dapat dengan jelas melihat bahwa kita dapat memilih parameter ini sedemikian rupa sehingga pembaruan bobot akhir sama dengan. Aturan pembaruan pertama: dan aturan pembaruan kedua (bayangkan bahwa ): λ

W: =W+λ1dL.SUM.dW
λ1=λ2B
W: =W+λ1dL.SEBUAHVGdW=W+λ2BdL.SUM.dW


Temuan yang sangat baik dari dontloo mungkin menyarankan bahwa menggunakan jumlah mungkin pendekatan yang lebih tepat. Untuk membenarkan rata-rata yang tampaknya lebih populer saya akan menambahkan bahwa menggunakan penjumlahan mungkin dapat menyebabkan beberapa masalah dengan regularisasi berat badan. Menyetel faktor penskalaan untuk regulator untuk ukuran batch yang berbeda mungkin sama menyebalkannya seperti menyetel tingkat pembelajaran.

pkubik
sumber