Menggunakan paket caret apakah mungkin untuk mendapatkan matriks kebingungan untuk nilai ambang tertentu?

13

Saya telah mendapatkan model regresi logistik (via train) untuk respons biner, dan saya telah mendapatkan matriks kebingungan logistik via confusionMatrixin caret. Ini memberi saya matriks kebingungan model logistik, meskipun saya tidak yakin apa ambang batas yang digunakan untuk mendapatkannya. Bagaimana cara mendapatkan matriks kebingungan untuk nilai ambang batas tertentu menggunakan confusionMatrixdi caret?

Susu Hitam
sumber
Saya tidak punya jawaban, tetapi seringkali pertanyaan seperti ini dijawab dalam file bantuan. Jika gagal, Anda dapat melihat kode sumbernya sendiri. Anda dapat mencetak sumber ke konsol dengan mengetik confusionmatrix, tanpa tanda kurung.
shadowtalker
Tidak jelas apa yang telah Anda lakukan persis. Apakah Anda memanggil glmfungsi dari statspaket dan meneruskan hasilnya confusionMatrix? Saya tidak tahu orang bisa melakukan itu, dan membaca manual itu tidak jelas sama sekali. Atau ada predictsesuatu? Contoh singkat akan membantu.
Calimo
1
@ Calimo Saya telah menggunakan trainfungsi caretagar sesuai dengan model, yang memungkinkan saya menentukannya sebagai glm dengan keluarga binomial. Saya kemudian menggunakan predictfungsi pada objek yang dihasilkan melalui train.
Susu Hitam

Jawaban:

10

Sebagian besar model klasifikasi dalam R menghasilkan prediksi kelas dan probabilitas untuk setiap kelas. Untuk data biner, dalam hampir setiap kasus, prediksi kelas didasarkan pada cutoff probabilitas 50%.

glmadalah sama. Dengan caret, menggunakan predict(object, newdata)memberi Anda kelas yang diprediksi dan predict(object, new data, type = "prob")akan memberikan Anda probabilitas kelas-spesifik (saat objectdihasilkan oleh train).

Anda dapat melakukan berbagai hal secara berbeda dengan mendefinisikan model Anda sendiri dan menerapkan cutoff apa pun yang Anda inginkan. Situs caret web ini juga memiliki contoh yang menggunakan resampling untuk mengoptimalkan kemungkinan cutoff.

tl; dr

confusionMatrix menggunakan kelas prediksi dan dengan demikian kemungkinan cutoff 50%

Maks

topepo
sumber
14

Ada cara yang cukup mudah, dengan asumsi tune <- train(...):

probsTest <- predict(tune, test, type = "prob")
threshold <- 0.5
pred      <- factor( ifelse(probsTest[, "yes"] > threshold, "yes", "no") )
pred      <- relevel(pred, "yes")   # you may or may not need this; I did
confusionMatrix(pred, test$response)

Jelas, Anda dapat menetapkan ambang batas untuk apa pun yang ingin Anda coba atau pilih yang "terbaik", di mana yang terbaik berarti spesifisitas dan sensitivitas gabungan tertinggi:

library(pROC)
probsTrain <- predict(tune, train, type = "prob")
rocCurve   <- roc(response = train$response,
                      predictor = probsTrain[, "yes"],
                      levels = rev(levels(train$response)))
plot(rocCurve, print.thres = "best")

Setelah melihat contoh yang diposting Max, saya tidak yakin apakah ada beberapa nuansa statistik yang membuat pendekatan saya kurang diinginkan.

efh0888
sumber
Dalam plot rocCurve yang di-output, apa arti ketiga nilai tersebut? misalnya pada data saya dikatakan 0,289 (0,853, 0,831). Apakah 0,289 menandakan ambang batas terbaik yang harus digunakan dalam demarkasi hasil biner? yaitu setiap kasus dengan probabilitas yang diprediksi> 0,289 akan dikodekan "1" dan setiap kasus dengan probabilitas yang diprediksi <0,289 akan dikodekan "0", daripada 0,5 ambang batas standar caretpaket?
gabungkan
2
ya itu persis benar, dan 2 nilai lainnya di dalam tanda kurung adalah sensitivitas dan spesifisitas (jujur, saya lupa yang mana)
efh0888
2
juga, sejak itu saya tahu Anda dapat mengekstraknya dari kurva roc rocCurve$thresholds[which(rocCurve$sensitivities + rocCurve$specificities == max(rocCurve$sensitivities + rocCurve$specificities))]yang juga memberi Anda fleksibilitas untuk menimbang mereka secara berbeda jika Anda ingin ... satu hal terakhir yang perlu diperhatikan adalah bahwa secara realistis, Anda mungkin ingin menyetel ambang (seperti Anda akan menggunakan model hyperparameter) seperti yang dijelaskan Max di sini .
efh0888