Paket R untuk Hutan Acak Tertimbang? Opsi classwt?

16

Saya mencoba menggunakan Random Forest untuk memprediksi hasil dari set data yang sangat tidak seimbang (tingkat kelas minoritas hanya sekitar 1% atau bahkan kurang). Karena algoritma Acak Hutan tradisional meminimalkan tingkat kesalahan keseluruhan, daripada memberikan perhatian khusus pada kelas-kelas minoritas, itu tidak langsung berlaku pada data yang tidak seimbang. Jadi saya ingin menetapkan biaya tinggi untuk kesalahan klasifikasi kelas minoritas (pembelajaran sensitif biaya).

Saya membaca beberapa sumber yang kita dapat menggunakan opsi classwtdari randomForestdalam R, tapi saya tidak tahu bagaimana menggunakan ini. Dan apakah kita punya alternatif lain selain randomForestfuntion?

Metariat
sumber

Jawaban:

29

Utas ini merujuk pada dua utas lainnya dan artikel bagus tentang masalah ini. Sepertinya classweighting dan downsampling sama-sama bagus. Saya menggunakan downsampling seperti yang dijelaskan di bawah ini.

Ingat set pelatihan harus besar karena hanya 1% akan mencirikan kelas langka. Kurang dari 25 ~ 50 sampel dari kelas ini mungkin akan bermasalah. Beberapa sampel yang mengkarakterisasi kelas akan membuat pola belajar menjadi kasar dan kurang dapat diproduksi kembali.

RF menggunakan voting mayoritas sebagai default. Prevalensi kelas dari set pelatihan akan beroperasi sebagai semacam sebelumnya yang efektif. Jadi kecuali kelas langka dapat dipisahkan dengan sempurna, kecil kemungkinan kelas langka ini akan memenangkan suara mayoritas saat memprediksi. Alih-alih mengumpulkan berdasarkan suara terbanyak, Anda dapat mengumpulkan fraksi suara.

Pengambilan sampel bertingkat dapat digunakan untuk meningkatkan pengaruh kelas langka. Hal ini dilakukan dengan biaya pada downsampling kelas-kelas lain. Pohon yang tumbuh akan menjadi kurang dalam karena sampel yang lebih sedikit perlu dipecah sehingga membatasi kompleksitas pola potensial yang dipelajari. Jumlah pohon yang ditanam harus besar misalnya 4000 sehingga sebagian besar pengamatan berpartisipasi dalam beberapa pohon.

Dalam contoh di bawah ini, saya telah mensimulasikan set data pelatihan 5000 sampel dengan 3 kelas dengan prevalensi masing-masing 1%, 49% dan 50%. Dengan demikian akan ada 50 sampel kelas 0. Gambar pertama menunjukkan kelas sebenarnya dari pelatihan ditetapkan sebagai fungsi dari dua variabel x1 dan x2. Gambar ini menunjukkan pola simulasi untuk belajar

Empat model dilatih: Sebuah model default, dan tiga model bertingkat dengan 1:10:10 1: 2: 2 dan 1: 1: 1 stratifikasi kelas. Utama sementara jumlah sampel inbag (termasuk redraws) di setiap pohon akan menjadi 5000, 1050, 250 dan 150. Karena saya tidak menggunakan voting mayoritas, saya tidak perlu membuat stratifikasi yang seimbang sempurna. Sebaliknya suara pada kelas langka bisa ditimbang 10 kali atau aturan keputusan lainnya. Biaya negatif palsu dan positif palsu Anda harus memengaruhi aturan ini.

Gambar berikut menunjukkan bagaimana stratifikasi memengaruhi fraksi suara. Perhatikan bahwa rasio kelas yang bertingkat selalu merupakan pusat prediksi. stratifikasi dan pemungutan suara

Terakhir Anda dapat menggunakan kurva ROC untuk menemukan aturan pemilihan yang memberi Anda trade-off yang baik antara spesifisitas dan sensitivitas. Garis hitam bukanlah stratifikasi, merah 1: 5: 5, hijau 1: 2: 2 dan biru 1: 1: 1. Untuk kumpulan data ini 1: 2: 2 atau 1: 1: 1 tampaknya merupakan pilihan terbaik. kurva roc

Ngomong-ngomong, fraksi suara ada di sini di luar tas crossvalidated.

Dan kodenya:

library(plotrix)
library(randomForest)
library(AUC)

make.data = function(obs=5000,vars=6,noise.factor = .2,smallGroupFraction=.01) {
X = data.frame(replicate(vars,rnorm(obs)))
yValue = with(X,sin(X1*pi)+sin(X2*pi*2)+rnorm(obs)*noise.factor)
yQuantile = quantile(yValue,c(smallGroupFraction,.5))
yClass = apply(sapply(yQuantile,function(x) x<yValue),1,sum)
yClass = factor(yClass)
print(table(yClass)) #five classes, first class has 1% prevalence only
Data=data.frame(X=X,y=yClass)
}

plot.separation = function(rf,...) {
triax.plot(rf$votes,...,col.symbols = c("#FF0000FF",
                                       "#00FF0010",
                                       "#0000FF10")[as.numeric(rf$y)])
}

#make data set where class "0"(red circles) are rare observations
#Class 0 is somewhat separateble from class "1" and fully separateble from class "2"
Data = make.data()
par(mfrow=c(1,1))
plot(Data[,1:2],main="separation problem: identify rare red circles",
     col = c("#FF0000FF","#00FF0020","#0000FF20")[as.numeric(Data$y)])

#train default RF and with 10x 30x and 100x upsumpling by stratification
rf1 = randomForest(y~.,Data,ntree=500, sampsize=5000)
rf2 = randomForest(y~.,Data,ntree=4000,sampsize=c(50,500,500),strata=Data$y)
rf3 = randomForest(y~.,Data,ntree=4000,sampsize=c(50,100,100),strata=Data$y)
rf4 = randomForest(y~.,Data,ntree=4000,sampsize=c(50,50,50)  ,strata=Data$y)

#plot out-of-bag pluralistic predictions(vote fractions).
par(mfrow=c(2,2),mar=c(4,4,3,3))
plot.separation(rf1,main="no stratification")
plot.separation(rf2,main="1:10:10")
plot.separation(rf3,main="1:5:5")
plot.separation(rf4,main="1:1:1")

par(mfrow=c(1,1))
plot(roc(rf1$votes[,1],factor(1 * (rf1$y==0))),main="ROC curves for four models predicting class 0")
plot(roc(rf2$votes[,1],factor(1 * (rf1$y==0))),col=2,add=T)
plot(roc(rf3$votes[,1],factor(1 * (rf1$y==0))),col=3,add=T)
plot(roc(rf4$votes[,1],factor(1 * (rf1$y==0))),col=4,add=T)
Soren Havelund Welling
sumber
oups satu gambar gambar mengatakan 1: 5: 5 bukannya 1: 2: 2
Soren Havelund Welling
terima kasih banyak atas jawaban terinci Anda, yang pasti akan banyak membantu saya dalam pekerjaan sehari-hari saya. Ada satu kalimat yang saya tidak mengerti: "Main sementara jumlah sampel inbag (termasuk redraws) di setiap pohon akan menjadi 5000.1050, 250 dan 150". Bisakah Anda jelaskan di mana saya berasal angka?
Metariat
1
kesenangan saya;) dalam contoh ini kelas langka memiliki 50 anggota. Jika bertingkat 1:10:10 kita perlu menentukan sampsize = c (50.500.500). 50 + 500 + 500 = 1050. Pohon 1050 sampel yang sudah dewasa akan memiliki total 1050x2 node.
Soren Havelund Welling
Maaf jika pertanyaan saya bodoh, tetapi apa artinya 1:10:10, 1: 2: 2 dan 1: 1: 1 stratifikasi di sini? Dan ketika Anda mengatakan "suara pada kelas langka dapat ditimbang 10 kali". Bagian mana dari kode yang menyatakan hal itu? Apakah 1:10:10? Terima kasih banyak!
Metariat
1
1:10:10 adalah rasio antar kelas. Kumpulan data yang disimulasikan dirancang untuk memiliki rasio 1:49:50. Rasio ini diubah dengan mengambil sampel dua kelas yang lebih besar. Dengan memilih contoh sampsize = c (50.500.500) sama dengan c (1,10,10) * 50 Anda mengubah rasio kelas di pohon. 50 adalah jumlah sampel dari kelas langka. Jika Anda selanjutnya menetapkan keep.inbag = BENAR dan memeriksa rf $ inbag, Anda akan melihat bahwa sampel kelas langka adalah inbag di ~ 2/3 pohon sedangkan setiap sampel kelas non-langka dimasukkan dalam sangat sedikit pohon karena down sampling.
Soren Havelund Welling