Regresi hutan acak tidak memprediksi lebih tinggi dari data pelatihan

12

Saya perhatikan ketika membangun model regresi hutan acak, setidaknya dalam R, nilai prediksi tidak pernah melebihi nilai maksimum dari variabel target yang terlihat dalam data pelatihan. Sebagai contoh, lihat kode di bawah ini. Saya sedang membangun model regresi untuk memprediksi mpgberdasarkan mtcarsdata. Saya membuat OLS dan model hutan acak, dan menggunakannya untuk memprediksi mpgmobil hipotetis yang seharusnya memiliki penghematan bahan bakar yang sangat baik. OLS memprediksi mpghutan tinggi , seperti yang diharapkan, tetapi hutan acak tidak. Saya perhatikan ini dalam model yang lebih kompleks juga. Kenapa ini?

> library(datasets)
> library(randomForest)
> 
> data(mtcars)
> max(mtcars$mpg)
[1] 33.9
> 
> set.seed(2)
> fit1 <- lm(mpg~., data=mtcars) #OLS fit
> fit2 <- randomForest(mpg~., data=mtcars) #random forest fit
> 
> #Hypothetical car that should have very high mpg
> hypCar <- data.frame(cyl=4, disp=50, hp=40, drat=5.5, wt=1, qsec=24, vs=1, am=1, gear=4, carb=1)
> 
> predict(fit1, hypCar) #OLS predicts higher mpg than max(mtcars$mpg)
      1 
37.2441 
> predict(fit2, hypCar) #RF does not predict higher mpg than max(mtcars$mpg)
       1 
30.78899 
Gaurav Bansal
sumber
Apakah umum bahwa orang menyebut regresi linier sebagai OLS? Saya selalu menganggap OLS sebagai metode.
Hao Ye
1
Saya percaya OLS adalah metode standar regresi linier, setidaknya di R.
Gaurav Bansal
Untuk pohon / hutan acak, prediksi adalah rata-rata data pelatihan dalam simpul yang sesuai. Jadi itu tidak bisa lebih besar dari nilai dalam data pelatihan.
Jason
1
Saya setuju tetapi sudah dijawab oleh setidaknya tiga pengguna lain.
HelloWorld

Jawaban:

12

Seperti yang telah disebutkan di jawaban sebelumnya, hutan acak untuk pohon regresi / regresi tidak menghasilkan prediksi yang diharapkan untuk titik data di luar ruang lingkup rentang data pelatihan karena mereka tidak dapat memperkirakan (well). Pohon regresi terdiri dari hierarki node, di mana setiap node menentukan tes yang akan dilakukan pada nilai atribut dan setiap node daun (terminal) menentukan aturan untuk menghitung output yang diprediksi. Dalam kasus Anda, pengamatan pengamatan mengalir melalui pohon ke simpul daun yang menyatakan, misalnya, "jika x> 335, maka y = 15", yang kemudian dirata-ratakan oleh hutan acak.

Berikut ini adalah skrip R yang memvisualisasikan situasi dengan hutan acak dan regresi linier. Dalam kasus hutan acak, prediksi konstan untuk menguji titik data yang baik di bawah nilai x data pelatihan terendah atau di atas nilai x data pelatihan tertinggi.

library(datasets)
library(randomForest)
library(ggplot2)
library(ggthemes)

# Import mtcars (Motor Trend Car Road Tests) dataset
data(mtcars)

# Define training data
train_data = data.frame(
    x = mtcars$hp,  # Gross horsepower
    y = mtcars$qsec)  # 1/4 mile time

# Train random forest model for regression
random_forest <- randomForest(x = matrix(train_data$x),
                              y = matrix(train_data$y), ntree = 20)
# Train linear regression model using ordinary least squares (OLS) estimator
linear_regr <- lm(y ~ x, train_data)

# Create testing data
test_data = data.frame(x = seq(0, 400))

# Predict targets for testing data points
test_data$y_predicted_rf <- predict(random_forest, matrix(test_data$x)) 
test_data$y_predicted_linreg <- predict(linear_regr, test_data)

# Visualize
ggplot2::ggplot() + 
    # Training data points
    ggplot2::geom_point(data = train_data, size = 2,
                        ggplot2::aes(x = x, y = y, color = "Training data")) +
    # Random forest predictions
    ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
                       ggplot2::aes(x = x, y = y_predicted_rf,
                                    color = "Predicted with random forest")) +
    # Linear regression predictions
    ggplot2::geom_line(data = test_data, size = 2, alpha = 0.7,
                       ggplot2::aes(x = x, y = y_predicted_linreg,
                                    color = "Predicted with linear regression")) +
    # Hide legend title, change legend location and add axis labels
    ggplot2::theme(legend.title = element_blank(),
                   legend.position = "bottom") + labs(y = "1/4 mile time",
                                                      x = "Gross horsepower") +
    ggthemes::scale_colour_colorblind()

Ekstrapolasi dengan hutan acak dan regresi linier

tuomastik
sumber
16

Tidak ada cara untuk Hutan Acak untuk memperkirakan seperti yang dilakukan OLS. Alasannya sederhana: prediksi dari Hutan Acak dilakukan melalui rata-rata hasil yang diperoleh di beberapa pohon. Pohon-pohon itu sendiri mengeluarkan nilai rata-rata sampel di setiap simpul terminal, yaitu daun. Tidak mungkin hasilnya berada di luar rentang data pelatihan, karena rata-rata selalu berada di dalam kisaran konstituennya.

Dengan kata lain, rata-rata tidak mungkin menjadi lebih besar (atau lebih rendah) daripada setiap sampel, dan regresi Random Forests didasarkan pada rata-rata.

Pembakar
sumber
11

Pohon Keputusan / Forrest Acak tidak dapat mengekstrapolasi di luar data pelatihan. Dan meskipun OLS dapat melakukan ini, prediksi seperti itu harus dilihat dengan hati-hati; karena pola yang teridentifikasi mungkin tidak berlanjut di luar kisaran yang diamati.

B. Frost
sumber