RandomForest dan bobot kelas

11

Pertanyaan dalam satu kalimat: Apakah ada yang tahu cara menentukan bobot kelas yang baik untuk hutan acak?

Penjelasan: Saya bermain-main dengan dataset yang tidak seimbang. Saya ingin menggunakan Rpaket randomForestuntuk melatih model pada dataset yang sangat miring dengan hanya sedikit contoh positif dan banyak contoh negatif. Saya tahu, ada metode lain dan pada akhirnya saya akan menggunakannya tetapi untuk alasan teknis, membangun hutan acak adalah langkah menengah. Jadi saya bermain-main dengan parameter classwt. Saya membuat dataset sangat buatan dari 5000 contoh negatif dalam disk dengan radius 2 dan kemudian saya sampel 100 contoh positif dalam disk dengan jari-jari 1. Apa yang saya duga adalah bahwa

1) tanpa pembobotan kelas model menjadi 'merosot', yaitu memprediksi di FALSEmana-mana.

2) dengan bobot kelas yang adil saya akan melihat 'titik hijau' di tengah, yaitu akan memprediksi disk dengan jari-jari 1 seolah-olah TRUEada contoh negatif.

Begini tampilannya:

masukkan deskripsi gambar di sini

Inilah yang terjadi tanpa memberi bobot: (panggilan adalah randomForest(x = train[, .(x,y)],y = as.factor(train$z),ntree = 50):)

masukkan deskripsi gambar di sini

Untuk memeriksa saya juga mencoba apa yang terjadi ketika saya menyeimbangkan dataset dengan downsampling kelas negatif sehingga hubungannya adalah 1: 1 lagi. Ini memberi saya hasil yang diharapkan:

masukkan deskripsi gambar di sini

Namun, ketika saya menghitung model dengan bobot kelas 'FALSE' = 1, 'TRUE' = 50 (ini adalah bobot yang adil karena ada 50 kali lebih banyak negatif daripada positif) maka saya mendapatkan ini:

masukkan deskripsi gambar di sini

Hanya ketika saya mengatur bobot ke beberapa nilai aneh seperti 'FALSE' = 0,05 dan 'TRUE' = 500000 maka saya mendapatkan hasil yang masuk akal:

masukkan deskripsi gambar di sini

Dan ini sangat tidak stabil, yaitu mengubah bobot 'SALAH' menjadi 0,01 membuat model tersebut merosot lagi (yaitu memprediksi di TRUEmana-mana).

Pertanyaan: Apakah ada yang tahu cara menentukan bobot kelas yang baik untuk hutan acak?

Kode R:

library(plot3D)
library(data.table)
library(randomForest)
set.seed(1234)
amountPos = 100
amountNeg = 5000

# positives
r = runif(amountPos, 0, 1)
phi = runif(amountPos, 0, 2*pi)
x = r*cos(phi)
y = r*sin(phi)
z = rep(T, length(x))
pos = data.table(x = x, y = y, z = z)

# negatives
r = runif(amountNeg, 0, 2)
phi = runif(amountNeg, 0, 2*pi)
x = r*cos(phi)
y = r*sin(phi)
z = rep(F, length(x))
neg = data.table(x = x, y = y, z = z)

train = rbind(pos, neg)

# draw train set, verify that everything looks ok
plot(train[z == F]$x, train[z == F]$y, col="red")
points(train[z == T]$x, train[z == T]$y, col="green")
# looks ok to me :-)

Color.interpolateColor = function(fromColor, toColor, amountColors = 50) {
  from_rgb = col2rgb(fromColor)
  to_rgb = col2rgb(toColor)

  from_r = from_rgb[1,1]
  from_g = from_rgb[2,1]
  from_b = from_rgb[3,1]

  to_r = to_rgb[1,1]
  to_g = to_rgb[2,1]
  to_b = to_rgb[3,1]

  r = seq(from_r, to_r, length.out = amountColors)
  g = seq(from_g, to_g, length.out = amountColors)
  b = seq(from_b, to_b, length.out = amountColors)

  return(rgb(r, g, b, maxColorValue = 255))
}
DataTable.crossJoin = function(X,Y) {
  stopifnot(is.data.table(X),is.data.table(Y))
  k = NULL
  X = X[, c(k=1, .SD)]
  setkey(X, k)
  Y = Y[, c(k=1, .SD)]
  setkey(Y, k)
  res = Y[X, allow.cartesian=TRUE][, k := NULL]
  X = X[, k := NULL]
  Y = Y[, k := NULL]
  return(res)
}

drawPredictionAreaSimple = function(model) {
  widthOfSquares = 0.1
  from = -2
  to = 2

  xTable = data.table(x = seq(from=from+widthOfSquares/2,to=to-widthOfSquares/2,by = widthOfSquares))
  yTable = data.table(y = seq(from=from+widthOfSquares/2,to=to-widthOfSquares/2,by = widthOfSquares))
  predictionTable = DataTable.crossJoin(xTable, yTable)
  pred = predict(model, predictionTable)
  res = rep(NA, length(pred))
  res[pred == "FALSE"] = 0
  res[pred == "TRUE"] = 1
  pred = res
  predictionTable = predictionTable[, PREDICTION := pred]
  #predictionTable = predictionTable[y == -1 & x == -1, PREDICTION := 0.99]
  col = Color.interpolateColor("red", "green")

  input = matrix(c(predictionTable$x, predictionTable$y), nrow = 2, byrow = T)
  m = daply(predictionTable, .(x, y), function(x) x$PREDICTION)
  image2D(z = m, x = sort(unique(predictionTable$x)), y = sort(unique(predictionTable$y)), col = col, zlim = c(0,1))
}


rfModel = randomForest(x = train[, .(x,y)],y = as.factor(train$z),ntree = 50)
rfModelBalanced = randomForest(x = train[, .(x,y)],y = as.factor(train$z),ntree = 50, classwt = c("FALSE" = 1, "TRUE" = 50))
rfModelBalancedWeird = randomForest(x = train[, .(x,y)],y = as.factor(train$z),ntree = 50, classwt = c("FALSE" = 0.05, "TRUE" = 500000))


drawPredictionAreaSimple(rfModel)
title("unbalanced")
drawPredictionAreaSimple(rfModelBalanced)
title("balanced with weights")
pos = train[z == T]
neg = train[z == F]
neg = neg[sample.int(neg[, .N], size = 100, replace = FALSE)]
trainSampled = rbind(pos, neg)
rfModelBalancedSampling = randomForest(x = trainSampled[, .(x,y)],y = as.factor(trainSampled$z),ntree = 50)
drawPredictionAreaSimple(rfModelBalancedSampling)
title("balanced with sampling")


drawPredictionAreaSimple(rfModelBalancedWeird)
title("balanced with weird weights")
Fabian Werner
sumber
jika sampsize berfungsi mengapa tidak menggunakannya? Saya juga menemukan sampsize bekerja lebih baik untuk menyelesaikan ini, seperti yang dilakukan orang lain. Juga lihat jawaban yang bagus di sini stats.stackexchange.com/questions/157714/...
katya
daply ada di plyr, Anda harus memanggil itu.
EngrStudent
Priw kelas Priw kelas. Tidak perlu menambahkan hingga satu. Diabaikan untuk regresi.
Diego
@Diego: Parameter ini tampaknya sangat bermasalah, saya tidak mengerti bagaimana parameter ini berperilaku (lihat komentar dalam pertanyaan!) ...
Fabian Werner

Jawaban:

2

Jangan gunakan cutoff keras untuk mengklasifikasikan keanggotaan keras, dan jangan gunakan KPI yang bergantung pada prediksi keanggotaan keras. Sebaliknya, bekerja dengan prediksi probabilistik, menggunakan predict(..., type="prob"), dan menilai ini menggunakan tepat .

Utas sebelumnya ini harus membantu: Mengapa akurasi bukan ukuran terbaik untuk menilai model klasifikasi? Cukup mengejutkan, saya yakin jawaban saya akan sangat membantu (maaf atas ketidakberdayaan), seperti jawaban saya sebelumnya .

Stephan Kolassa
sumber
1
ini adalah topik yang kamu bicarakan hari ini, kan? Saya akan melihat ini dalam pekerjaan saya :-)
Jonas Heidelberg