Bagaimana model regresi logistik sederhana mencapai akurasi klasifikasi 92% pada MNIST?

70

Meskipun semua gambar dalam dataset MNIST dipusatkan, dengan skala yang sama, dan menghadap ke atas tanpa rotasi, mereka memiliki variasi tulisan tangan yang signifikan yang membuat saya bingung bagaimana model linier mencapai akurasi klasifikasi tinggi.

Sejauh yang saya bisa memvisualisasikan, mengingat variasi tulisan tangan yang signifikan, digit harus tidak dapat dipisahkan secara linear dalam ruang dimensi 784, yaitu, harus ada sedikit batas non-linear yang kompleks (meskipun tidak terlalu kompleks) yang memisahkan digit yang berbeda , mirip dengan contoh dikutip dengan baik di mana kelas positif dan negatif tidak dapat dipisahkan oleh classifier linier. Tampaknya membingungkan bagi saya bagaimana regresi logistik multi-kelas menghasilkan akurasi tinggi dengan fitur yang sepenuhnya linier (tidak ada fitur polinomial).XOR

Sebagai contoh, mengingat piksel apa pun dalam gambar, variasi tulisan tangan yang berbeda dari angka dan dapat membuat piksel tersebut menyala atau tidak. Oleh karena itu, dengan serangkaian bobot yang dipelajari, setiap piksel dapat membuat tampilan digit sebagai dan . Hanya dengan kombinasi nilai piksel yang memungkinkan untuk mengatakan apakah digit adalah atau . Ini berlaku untuk sebagian besar pasangan digit. Jadi, bagaimana regresi logistik, yang secara membabi buta mendasarkan keputusannya secara independen pada semua nilai piksel (tanpa mempertimbangkan ketergantungan antar-piksel sama sekali), dapat mencapai akurasi tinggi tersebut.232323

Saya tahu bahwa saya salah di suatu tempat atau hanya terlalu memperkirakan variasi dalam gambar. Namun, alangkah baiknya jika seseorang dapat membantu saya dengan intuisi tentang bagaimana digit 'hampir' terpisah secara linear.

Nitish Agarwal
sumber
Lihatlah buku teks Pembelajaran Statistik dengan Sparsity: Lasso dan Generalisasi 3.3.1 Contoh: Digit Tulisan Tangan web.stanford.edu/ ~ hastie
Adrian
Saya ingin tahu: seberapa baik sesuatu seperti model linear yang dihukum (yaitu, glmnet) lakukan pada masalah? Jika saya ingat, apa yang Anda laporkan adalah akurasi out-of-sample yang belum dilaparkan.
Cliff AB

Jawaban:

88

tl; dr Meskipun ini adalah kumpulan data klasifikasi gambar, ini tetap merupakan tugas yang sangat mudah , yang dengannya seseorang dapat dengan mudah menemukan pemetaan langsung dari input ke prediksi.


Menjawab:

Ini adalah pertanyaan yang sangat menarik dan berkat kesederhanaan regresi logistik Anda benar-benar dapat menemukan jawabannya.

Apa yang dilakukan regresi logistik adalah agar setiap gambar menerima input dan mengalikannya dengan bobot untuk menghasilkan prediksi. Yang menarik adalah karena pemetaan langsung antara input dan output (yaitu tidak ada lapisan tersembunyi), nilai setiap bobot sesuai dengan seberapa banyak masing-masing dari input diperhitungkan saat menghitung probabilitas setiap kelas. Sekarang, dengan mengambil bobot untuk setiap kelas dan membentuknya kembali menjadi (yaitu resolusi gambar), kita dapat mengetahui piksel apa yang paling penting untuk perhitungan setiap kelas .78478428×28

Perhatikan, sekali lagi, bahwa ini adalah bobotnya .

Sekarang lihat gambar di atas dan fokus pada dua digit pertama (yaitu nol dan satu). Bobot biru berarti bahwa intensitas piksel ini banyak berkontribusi untuk kelas itu dan nilai merah berarti memberi kontribusi negatif.

Sekarang bayangkan, bagaimana seseorang menggambar angka ? Dia menggambar bentuk melingkar yang kosong di antaranya. Itulah tepatnya yang diangkat oleh beban. Bahkan jika seseorang menggambar tengah gambar, itu dihitung negatif sebagai nol. Jadi untuk mengenali nol Anda tidak perlu beberapa filter canggih dan fitur tingkat tinggi. Anda bisa melihat lokasi piksel yang diambil dan menilai berdasarkan ini.0

Hal yang sama untuk . Itu selalu memiliki garis vertikal lurus di tengah gambar. Semua yang lain terhitung negatif.1

Sisa digitnya sedikit lebih rumit, tetapi dengan sedikit imajinasi Anda dapat melihat , , dan . Angka-angka lainnya sedikit lebih sulit, yang sebenarnya membatasi regresi logistik untuk mencapai tahun 90-an.2378

Melalui ini Anda dapat melihat bahwa regresi logistik memiliki peluang yang sangat baik untuk mendapatkan banyak gambar dengan benar dan itulah mengapa nilainya sangat tinggi.


Kode untuk mereproduksi gambar di atas sedikit bertanggal, tetapi di sini Anda mulai:

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Create model
x = tf.placeholder(tf.float32, shape=(None, 784))
y = tf.placeholder(tf.float32, shape=(None, 10))

W = tf.Variable(tf.zeros((784,10)))
b = tf.Variable(tf.zeros((10)))
z = tf.matmul(x, W) + b

y_hat = tf.nn.softmax(z)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_hat), reduction_indices=[1]))
optimizer = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) # 

correct_pred = tf.equal(tf.argmax(y_hat, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Train model
batch_size = 64
with tf.Session() as sess:

    loss_tr, acc_tr, loss_ts, acc_ts = [], [], [], []

    sess.run(tf.global_variables_initializer()) 

    for step in range(1, 1001):

        x_batch, y_batch = mnist.train.next_batch(batch_size) 
        sess.run(optimizer, feed_dict={x: x_batch, y: y_batch})

        l_tr, a_tr = sess.run([cross_entropy, accuracy], feed_dict={x: x_batch, y: y_batch})
        l_ts, a_ts = sess.run([cross_entropy, accuracy], feed_dict={x: mnist.test.images, y: mnist.test.labels})
        loss_tr.append(l_tr)
        acc_tr.append(a_tr)
        loss_ts.append(l_ts)
        acc_ts.append(a_ts)

    weights = sess.run(W)      
    print('Test Accuracy =', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})) 

# Plotting:
for i in range(10):
    plt.subplot(2, 5, i+1)
    weight = weights[:,i].reshape([28,28])
    plt.title(i)
    plt.imshow(weight, cmap='RdBu')  # as noted by @Eric Duminil, cmap='gray' makes the numbers stand out more
    frame1 = plt.gca()
    frame1.axes.get_xaxis().set_visible(False)
    frame1.axes.get_yaxis().set_visible(False)
Djib2011
sumber
12
Terima kasih untuk ilustrasinya. Gambar-gambar berat ini membuatnya lebih jelas seperti bagaimana akurasi sangat tinggi. Penggandaan dot dari gambar digit tulisan tangan dengan gambar berat yang sesuai dengan label sebenarnya dari gambar 'tampaknya' menjadi yang tertinggi dibandingkan dengan produk titik dengan label berat lainnya untuk sebagian besar (masih 92% terlihat seperti banyak bagi saya) dari gambar di MNIST. Namun, sedikit mengejutkan bahwa dan atau dan jarang salah diklasifikasi satu sama lain setelah memeriksa matriks kebingungan. Bagaimanapun, ini adalah apa adanya. Data tidak pernah bohong. :)2378
Nitish Agarwal
13
Tentu saja ini membantu bahwa sampel MNIST dipusatkan, diskalakan, dan dinormalisasi kontras sebelum pengklasifikasi melihatnya. Anda tidak perlu menjawab pertanyaan seperti "bagaimana jika ujung nol benar-benar melewati bagian tengah kotak?" karena pra-prosesor sudah berjalan jauh untuk membuat semua nol terlihat sama.
hobbs
1
@EricDuminil Saya menambahkan pujian pada skrip dengan saran Anda. Terima kasih banyak untuk masukannya! : D
Djib2011
1
@NitishAgarwal, Jika Anda berpikir bahwa jawaban ini adalah Jawaban untuk Pertanyaan Anda, pertimbangkan untuk menandainya.
sintax
13
Untuk seseorang yang tertarik tetapi tidak terlalu akrab dengan pemrosesan semacam ini, jawaban ini memberikan contoh intuitif yang fantastis tentang mekanika.
chrylis -on strike-