Mengapa Python scikit-learn LDA tidak bekerja dengan benar dan bagaimana cara menghitung LDA melalui SVD?

26

Saya menggunakan Linear Discriminant Analysis (LDA) dari scikit-learnperpustakaan pembelajaran mesin (Python) untuk pengurangan dimensi dan sedikit ingin tahu tentang hasilnya. Sekarang saya bertanya-tanya apa yang dilakukan LDA scikit-learnsehingga hasilnya terlihat berbeda dari, misalnya, pendekatan manual atau LDA yang dilakukan di R. Akan lebih bagus jika seseorang bisa memberi saya wawasan di sini.

Apa yang pada dasarnya paling memprihatinkan adalah bahwa scikit-plotmenunjukkan korelasi antara dua variabel di mana harus ada korelasi 0.

Untuk tes, saya menggunakan dataset Iris dan 2 diskriminan linier pertama terlihat seperti ini:

IMG-1. LDA melalui scikit-learn

masukkan deskripsi gambar di sini

Ini pada dasarnya konsisten dengan hasil yang saya temukan di dokumentasi scikit-learning di sini.

Sekarang, saya melalui LDA langkah demi langkah dan mendapat proyeksi yang berbeda. Saya mencoba berbagai pendekatan untuk mencari tahu apa yang sedang terjadi:

IMG-2. LDA pada data mentah (tidak ada pemusatan, tidak ada standardisasi)

masukkan deskripsi gambar di sini

Dan di sini akan menjadi pendekatan langkah-demi-langkah jika saya menstandarkan (z-skor normalisasi; varian unit) data terlebih dahulu. Saya melakukan hal yang sama hanya dengan pemusatan-kejam saja, yang seharusnya mengarah pada gambar proyeksi relatif yang sama (dan memang memang demikian).

IMG-3. Step-by-step LDA setelah mean-centering, atau standardisasi

masukkan deskripsi gambar di sini

IMG-4. LDA dalam R (pengaturan default)

LDA di IMG-3 di mana saya memusatkan data (yang akan menjadi pendekatan yang disukai) terlihat juga persis sama dengan yang saya temukan di Post oleh seseorang yang melakukan LDA di R masukkan deskripsi gambar di sini


Kode untuk referensi

Saya tidak ingin menempelkan semua kode di sini, tetapi saya telah mengunggahnya sebagai notebook IPython di sini dipecah menjadi beberapa langkah yang saya gunakan (lihat di bawah) untuk proyeksi LDA.

  1. Langkah 1: Menghitung vektor rata-rata d-dimensi
    mi=1nixDinxk
  2. Langkah 2: Menghitung Matriks Sebar

    2.1 dalam kelas scatter matrix dihitung dengan persamaan berikut: S W = c Σ i = 1 S i = c Σ i = 1 n Σ xD i ( x - m i )SW

    SW=i=1cSi=i=1cxDin(xmi)(xmi)T

    2.2 antara kelas scatter matrix dihitung dengan persamaan berikut: S B = c Σ i = 1 n i ( m i - m ) ( m i - m ) T di mana m adalah mean keseluruhan.SB

    SB=i=1cni(mim)(mim)T
    m
  3. Langkah 3. Memecahkan masalah nilai eigen umum untuk matriks SW-1SB

    3.1. Menyortir vektor eigen dengan menurunkan nilai eigen

    3.2. Memilih k vektor eigen dengan nilai eigen terbesar. Menggabungkan dua vektor eigen dengan nilai eigen tertinggi untuk membangun kami berdimensi vektor eigen matriks Wd×kW

  4. Langkah 5: Mengubah sampel ke subruang baru

    y=WT×x.
amuba kata Reinstate Monica
sumber
Saya belum melalui untuk mencari perbedaan, tetapi Anda dapat melihat apa yang dilakukan scikit-belajar di sumbernya .
Dougal
Sepertinya mereka juga melakukan standarisasi (pemusatan dan kemudian penskalaan melalui pembagian dengan standar deviasi). Ini, saya akan mengharapkan hasil yang mirip dengan yang ada di plot ke-3 saya (dan plot R) ... hmm
Aneh: plot yang Anda dapatkan dengan scikit (dan yang mereka tampilkan di dokumentasi) tidak masuk akal. LDA selalu menghasilkan proyeksi yang memiliki korelasi nol, tetapi jelas ada korelasi yang sangat kuat antara proyeksi scikit pada sumbu 1 dan 2. Ada sesuatu yang salah di sana.
Amoeba berkata Reinstate Monica
@ameoba Ya, saya juga berpikir begitu. Yang juga aneh adalah bahwa plot yang sama dengan yang saya tunjukkan untuk scikit ada dalam dokumentasi contoh: scikit-learn.org/stable/auto_examples/decomposition/... Itu membuat saya berpikir bahwa penggunaan scikit itu benar, tetapi ada sesuatu yang aneh tentang fungsi LDA
@SebastianRaschka: Ya, saya perhatikan. Memang aneh. Namun, perhatikan bahwa plot LDA pertama Anda (non-scikit) juga menunjukkan korelasi non-nol dan karenanya ada sesuatu yang salah dengan plot tersebut. Apakah Anda memusatkan data? Proyeksi pada sumbu kedua tampaknya tidak memiliki rata-rata nol.
Amoeba berkata Reinstate Monica

Jawaban:

20

Pembaruan: Berkat diskusi ini, scikit-learntelah diperbarui dan berfungsi dengan benar sekarang. Kode sumber LDA-nya dapat ditemukan di sini . Masalah aslinya adalah karena bug kecil (lihat diskusi github ini ) dan jawaban saya sebenarnya tidak menunjuk dengan benar (permintaan maaf atas kebingungan yang disebabkan). Karena semua itu tidak masalah lagi (bug diperbaiki), saya mengedit jawaban saya untuk fokus pada bagaimana LDA dapat diselesaikan melalui SVD, yang merupakan algoritma default di scikit-learn.


ΣWΣBΣW1ΣB

  1. ΣW1/2

    Perhatikan bahwa jika Anda memiliki dekomposisi eigen ΣW=USUΣW1/2=US1/2UXW=ULVΣW1/2=UL1U

  2. ΣW1/2ΣBΣW1/2A

    XBΣW1/2

  3. AΣW1/2A

    Memang, jika adalah vektor eigen dari matriks di atas, maka Σ - 1 / 2 W Σ B Σ - 1 / 2 Wa

    ΣW1/2ΣBΣW1/2a=λa,
    ΣW1/2a=ΣW1/2a
    ΣW1ΣBa=λa.

Singkatnya, LDA setara dengan memutihkan matriks sarana kelas sehubungan dengan kovarians dalam kelas, melakukan PCA pada sarana kelas, dan mentransformasikan kembali sumbu utama yang dihasilkan ke ruang asli (tanpa dikhawatirkan).

Ini ditunjukkan misalnya dalam The Elements of Statistics Learning , bagian 4.3.3. Dalam scikit-learnhal ini adalah cara standar untuk menghitung LDA karena SVD dari matriks data secara numerik lebih stabil daripada dekomposisi eigen dari matriks kovariansnya.

ΣW-1/2scikit-learn L.-1UUL.-1U

amuba kata Reinstate Monica
sumber
1
Terima kasih atas jawaban yang bagus ini. Saya menghargai bahwa Anda meluangkan waktu untuk menulisnya dengan baik. Mungkin Anda bisa menyebutkannya dalam diskusi di GitHub; Saya yakin itu akan membantu memperbaiki LDA di versi sci-kit berikutnya
@SebastianRaschka: Saya tidak punya akun di GitHub. Tetapi jika Anda mau, Anda bisa memberikan tautan ke utas ini.
Amoeba berkata Reinstate Monica
ΣW-1ΣBΣW-1
2
ΣBΣW-1ΣBΣW-1(μ1-μ2)μsaya
Amuba mengatakan Reinstate Monica
3

Hanya untuk menutup pertanyaan ini, masalah yang dibahas dengan LDA telah diperbaiki di scikit-learn 0.15.2 .


sumber