What is happening here, when I use squared loss in the logistic regression setting?


I am trying to use squared loss to do binary classification on a toy data set.

I am using the mtcars data set, with miles per gallon and weight to predict the transmission type. The plot below shows the data for the two transmission types in different colors, together with the decision boundaries generated by the different loss functions. The squared loss is $\sum_i (y_i - p_i)^2$, where $y_i$ is the ground-truth label (0 or 1) and $p_i$ is the predicted probability $p_i = \operatorname{logit}^{-1}(\beta^T x_i)$. In other words, I replace the logistic loss with squared loss in a classification setting; the other parts are the same.

As a toy example with the mtcars data, in many cases I get a model "similar" to logistic regression (see the following figure, with random seed 0).

[figure: mtcars data with the decision boundaries from logistic loss (solid) and squared loss (dashed), random seed 0]

But in some cases (e.g. if we do set.seed(1)), squared loss does not seem to work well. [figure: same plot with random seed 1, where the squared-loss boundary is clearly off] What is happening here? Does the optimization not converge? Is logistic loss easier to optimize than squared loss? Any help would be appreciated.


Code

# fit ordinary logistic regression and draw its decision boundary
d=mtcars[,c("am","mpg","wt")]
plot(d$mpg,d$wt,col=factor(d$am))
lg_fit=glm(am~.,d, family = binomial())
abline(-lg_fit$coefficients[1]/lg_fit$coefficients[3],
       -lg_fit$coefficients[2]/lg_fit$coefficients[3])
grid()

# squared loss on the predicted probabilities p = plogis(x %*% w)
lossSqOnBinary<-function(x,y,w){
  p=plogis(x %*% w)
  return(sum((y-p)^2))
}

# ----------------------------------------------------------------
# note: this random seed matters for whether the squared-loss fit works
# ----------------------------------------------------------------
set.seed(0)

x0=runif(3)                    # random starting values
x=as.matrix(cbind(1,d[,2:3]))  # design matrix with an intercept column
y=d$am
opt=optim(x0, lossSqOnBinary, method="BFGS", x=x,y=y)

# decision boundary from the squared-loss fit
abline(-opt$par[1]/opt$par[3],
       -opt$par[2]/opt$par[3], lty=2)
legend(25,5,c("logistic loss","squared loss"), lty=c(1,2))
Haitao Du
Perhaps the random starting values are poor ones. Why not pick better ones?
whuber
@whuber logistic loss is convex, so the starting point should not matter. What about squared loss on p and y, is it convex?
Haitao Du
I cannot reproduce what you describe. optim tells you it has not finished, that is all: it is still converging. You may learn a lot by re-running your code with the additional argument control=list(maxit=10000), plotting the fit, and comparing its coefficients to the original ones.
whuber
@amoeba thanks for your comments; I revised the question. Hopefully it is better.
Haitao Du
@amoeba I will revise the legend, but doesn't this statement fix (3)? "I am using the mtcars data set, using miles per gallon and weight to predict the transmission type. The plot below shows the data for the two transmission types in different colors, and the decision boundaries generated by the different loss functions."
Haitao Du

Answers:


It seems like you have fixed the issue in your particular example, but I think it is still worth a more careful study of the difference between least squares and maximum likelihood logistic regression.

Let's set up some notation. Let $L_S(y_i, \hat y_i) = \frac{1}{2}(y_i - \hat y_i)^2$ and $L_L(y_i, \hat y_i) = y_i \log \hat y_i + (1 - y_i)\log(1 - \hat y_i)$. If we are doing maximum likelihood (or minimizing the negative log likelihood, as I am here), we have

$$\hat\beta_L := \operatorname*{argmin}_{b \in \mathbb{R}^p} \; -\sum_{i=1}^n y_i \log g^{-1}(x_i^T b) + (1 - y_i)\log\left(1 - g^{-1}(x_i^T b)\right)$$

with $g$ being our link function.

Alternatively, we have

$$\hat\beta_S := \operatorname*{argmin}_{b \in \mathbb{R}^p} \frac{1}{2}\sum_{i=1}^n \left(y_i - g^{-1}(x_i^T b)\right)^2$$

as the least squares solution. Thus $\hat\beta_S$ minimizes $L_S$, and similarly $\hat\beta_L$ minimizes $L_L$.

Let $f_S$ and $f_L$ be the objective functions corresponding to minimizing $L_S$ and $L_L$ respectively, as is done for $\hat\beta_S$ and $\hat\beta_L$. Finally, let $h = g^{-1}$ so that $\hat y_i = h(x_i^T b)$. Note that if we use the canonical link we get

$$h(z) = \frac{1}{1 + e^{-z}} \implies h'(z) = h(z)\left(1 - h(z)\right).$$
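As a quick numerical sanity check of that identity (a minimal sketch; the grid of $z$ values and the step size are arbitrary choices):

# check h'(z) = h(z) * (1 - h(z)) for the logistic function by central differences
z   <- seq(-5, 5, by = 0.5)
eps <- 1e-6
numerical_deriv <- (plogis(z + eps) - plogis(z - eps)) / (2 * eps)
closed_form     <- plogis(z) * (1 - plogis(z))
max(abs(numerical_deriv - closed_form))   # essentially zero, up to finite-difference error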


For regular logistic regression we have

$$\frac{\partial f_L}{\partial b_j} = -\sum_{i=1}^n h'(x_i^T b)\, x_{ij} \left(\frac{y_i}{h(x_i^T b)} - \frac{1 - y_i}{1 - h(x_i^T b)}\right).$$

Using $h' = h(1 - h)$ we can simplify this to

$$\frac{\partial f_L}{\partial b_j} = -\sum_{i=1}^n x_{ij}\left(y_i(1 - \hat y_i) - (1 - y_i)\hat y_i\right) = -\sum_{i=1}^n x_{ij}(y_i - \hat y_i)$$

so

$$\nabla f_L(b) = -X^T(Y - \hat Y).$$
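As a hedged check of this gradient formula, one can compare it with a numerical gradient on the design matrix from the question (a sketch; it assumes the numDeriv package is available and uses an arbitrary test point b):

# compare the analytic gradient -X^T (Y - Yhat) with a numerical gradient of f_L
library(numDeriv)

d <- mtcars[, c("am", "mpg", "wt")]
X <- as.matrix(cbind(1, d[, c("mpg", "wt")]))
Y <- d$am

f_L <- function(b) {
  p <- plogis(X %*% b)
  -sum(Y * log(p) + (1 - Y) * log(1 - p))   # negative log likelihood
}

b <- c(0.1, -0.2, 0.3)                       # arbitrary test point
analytic  <- as.vector(-t(X) %*% (Y - plogis(X %*% b)))
numerical <- numDeriv::grad(f_L, b)
max(abs(analytic - numerical))               # should be tiny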

Next let's look at the second derivatives. The Hessian is

$$H_L := \frac{\partial^2 f_L}{\partial b_j \partial b_k} = \sum_{i=1}^n x_{ij} x_{ik}\, \hat y_i (1 - \hat y_i).$$

This means $H_L = X^T A X$ where $A = \operatorname{diag}\left(\hat Y (1 - \hat Y)\right)$. No matter what $\hat Y$ or $Y$ is, $H_L$ is positive semidefinite, so $f_L$ is convex in $b$.
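For example, one can form $H_L = X^TAX$ on the question's design matrix at an arbitrary coefficient vector and confirm that its eigenvalues are non-negative (a minimal sketch):

# H_L = X^T A X with A = diag(yhat * (1 - yhat)) should be positive semidefinite
d <- mtcars[, c("am", "mpg", "wt")]
X <- as.matrix(cbind(1, d[, c("mpg", "wt")]))

b    <- c(0.1, -0.2, 0.3)                  # arbitrary coefficient vector
yhat <- as.vector(plogis(X %*% b))
A    <- diag(yhat * (1 - yhat))
H_L  <- t(X) %*% A %*% X

eigen(H_L, symmetric = TRUE)$values        # all non-negative (up to rounding)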


Now let's compare this with least squares. We have

$$\frac{\partial f_S}{\partial b_j} = -\sum_{i=1}^n (y_i - \hat y_i)\, h'(x_i^T b)\, x_{ij},$$

which gives

$$\nabla f_S(b) = -X^T A (Y - \hat Y).$$

This is a vital point: the gradient is almost the same, except that for every $i$ we have $\hat y_i(1 - \hat y_i) \in (0, 1)$, so relative to $\nabla f_L$ we are shrinking the gradient. This makes convergence slower.
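To see the shrinkage numerically, one can evaluate both gradients at the same arbitrary point; the residuals $y_i - \hat y_i$ in $\nabla f_S$ get reweighted by $\hat y_i(1 - \hat y_i) \le \tfrac14$ before being summed (a sketch on the question's design matrix):

# compare grad f_L(b) = -X^T (Y - Yhat) with grad f_S(b) = -X^T A (Y - Yhat)
d <- mtcars[, c("am", "mpg", "wt")]
X <- as.matrix(cbind(1, d[, c("mpg", "wt")]))
Y <- d$am

b    <- c(0.1, -0.2, 0.3)                  # arbitrary test point
yhat <- as.vector(plogis(X %*% b))
A    <- diag(yhat * (1 - yhat))

grad_L <- -t(X) %*% (Y - yhat)
grad_S <- -t(X) %*% A %*% (Y - yhat)
cbind(grad_L, grad_S)                      # the squared-loss gradient is typically much smaller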

For the Hessian we can first write

$$\frac{\partial f_S}{\partial b_j} = -\sum_{i=1}^n x_{ij}(y_i - \hat y_i)\hat y_i(1 - \hat y_i) = -\sum_{i=1}^n x_{ij}\left(y_i \hat y_i - (1 + y_i)\hat y_i^2 + \hat y_i^3\right).$$

This leads us to

$$H_S := \frac{\partial^2 f_S}{\partial b_j \partial b_k} = -\sum_{i=1}^n x_{ij} x_{ik}\, h'(x_i^T b)\left(y_i - 2(1 + y_i)\hat y_i + 3\hat y_i^2\right).$$

Let $B = \operatorname{diag}\left(y_i - 2(1 + y_i)\hat y_i + 3\hat y_i^2\right)$. We now have

$$H_S = -X^T A B X.$$

Unfortunately for us, the weights in $B$ can take either sign: if $y_i = 0$ then $y_i - 2(1 + y_i)\hat y_i + 3\hat y_i^2 = \hat y_i(3\hat y_i - 2)$, which is positive iff $\hat y_i > \frac{2}{3}$. Similarly, if $y_i = 1$ then $y_i - 2(1 + y_i)\hat y_i + 3\hat y_i^2 = 1 - 4\hat y_i + 3\hat y_i^2$, which is positive when $\hat y_i < \frac{1}{3}$ (it's also positive for $\hat y_i > 1$, but that's not possible). This means that $H_S$ is not necessarily PSD, so not only are we squashing our gradients, which will make learning harder, but we've also messed up the convexity of our problem.
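A minimal way to see this concretely is to look at a single observation rather than the whole Hessian: for $y_i = 0$ with a large positive linear predictor (so $\hat y_i > \frac23$), the per-observation squared loss is locally concave in the linear predictor, and a numerical second derivative matches the closed form above (a sketch; the value $z = 3$ is an arbitrary choice):

# one badly misfit observation: y = 0 but the linear predictor z gives yhat > 2/3
y <- 0
z <- 3
yhat <- plogis(z)

# closed-form second derivative of 0.5 * (y - plogis(z))^2 from the derivation above
d2_closed <- -dlogis(z) * (y - 2 * (1 + y) * yhat + 3 * yhat^2)

# numerical second derivative for comparison
f   <- function(z) 0.5 * (y - plogis(z))^2
eps <- 1e-4
d2_numerical <- (f(z + eps) - 2 * f(z) + f(z - eps)) / eps^2

c(closed = d2_closed, numerical = d2_numerical)   # both negative: locally concave here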


All in all, it's no surprise that least squares logistic regression struggles sometimes, and in your example you've got enough fitted values close to $0$ or $1$ so that $\hat y_i(1 - \hat y_i)$ can be pretty small and thus the gradient is quite flattened.
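One can check how flat things get on the question's data by fitting the ordinary logistic regression and looking at the weights $\hat y_i(1 - \hat y_i)$ (a sketch; if, as argued above, many fitted probabilities sit near 0 or 1, most of these weights will be far below their maximum of $\tfrac14$):

# distribution of the per-observation weights yhat * (1 - yhat) on the mtcars fit
d   <- mtcars[, c("am", "mpg", "wt")]
fit <- glm(am ~ mpg + wt, data = d, family = binomial())

yhat <- fitted(fit)
w    <- yhat * (1 - yhat)     # weights that scale the squared-loss gradient
summary(w)                    # values near 0 correspond to a flattened gradient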

Connecting this to neural networks, even though this is but a humble logistic regression I think with squared loss you're experiencing something like what Goodfellow, Bengio, and Courville are referring to in their Deep Learning book when they write the following:

One recurring theme throughout neural network design is that the gradient of the cost function must be large and predictable enough to serve as a good guide for the learning algorithm. Functions that saturate (become very flat) undermine this objective because they make the gradient become very small. In many cases this happens because the activation functions used to produce the output of the hidden units or the output units saturate. The negative log-likelihood helps to avoid this problem for many models. Many output units involve an exp function that can saturate when its argument is very negative. The log function in the negative log-likelihood cost function undoes the exp of some output units. We will discuss the interaction between the cost function and the choice of output unit in Sec. 6.2.2.

and, in 6.2.2,

Unfortunately, mean squared error and mean absolute error often lead to poor results when used with gradient-based optimization. Some output units that saturate produce very small gradients when combined with these cost functions. This is one reason that the cross-entropy cost function is more popular than mean squared error or mean absolute error, even when it is not necessary to estimate an entire distribution p(y|x).

(both excerpts are from chapter 6).

jld
I really like that you helped me derive the derivatives and the Hessian. I will check them more carefully tomorrow.
Haitao Du
@hxd1011 you're very welcome, and thanks for the link to that older question of yours! I've really been meaning to go through this more carefully so this was a great excuse :)
jld
I carefully read the math and verified it with code. I found that the Hessian for squared loss does not match the numerical approximation. Could you check it? I am more than happy to show you the code if you want.
Haitao Du
@hxd1011 I just went through the derivation again and I think there's a sign error: for $H_S$ I think everywhere that I have $y_i - 2(1 - y_i)\hat y_i + 3\hat y_i^2$ it should be $y_i - 2(1 + y_i)\hat y_i + 3\hat y_i^2$. Could you recheck and tell me if that fixes it? Thanks a lot for the correction.
jld
@hxd1011 glad that fixed it! thanks again for finding that
jld

I would like to thank @whuber and @Chaconne for their help. Especially @Chaconne: this derivation is what I have wished to have for years.

The problem IS in the optimization part. If we set the random seed to 1, the default BFGS does not work. But if we change the algorithm or increase the maximum number of iterations, it works again.
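For reference, here is a minimal sketch of that fix on the original mtcars example from the question (same loss function and starting values as there, with set.seed(1)); per the discussion above, a larger iteration budget and/or a different optim method should let this start converge:

# re-run the question's squared-loss fit with the "bad" seed, but give optim more room
d <- mtcars[, c("am", "mpg", "wt")]
x <- as.matrix(cbind(1, d[, 2:3]))
y <- d$am

lossSqOnBinary <- function(x, y, w){
  p <- plogis(x %*% w)
  sum((y - p)^2)
}

set.seed(1)
x0 <- runif(3)

opt_bfgs <- optim(x0, lossSqOnBinary, method = "BFGS",
                  control = list(maxit = 10000), x = x, y = y)
opt_nm   <- optim(x0, lossSqOnBinary, method = "Nelder-Mead",
                  control = list(maxit = 10000), x = x, y = y)

c(BFGS = opt_bfgs$convergence, NelderMead = opt_nm$convergence)   # 0 means converged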

As @Chaconne mentioned, the problem is that squared loss for classification is non-convex and harder to optimize. To add to @Chaconne's math, I would like to present some visualizations of the logistic loss and the squared loss.

We will change the demo data from mtcars, since the original toy example has 3 coefficients including the intercept. Instead we will use a toy data set generated with mlbench; in this data set we fit only 2 parameters, which is better for visualization.

Here is the demo

  • The data are shown in the left figure: we have two classes in two colors. x and y are the two features of the data. In addition, the red line is the linear classifier from logistic loss, and the blue line is the linear classifier from squared loss.

  • The middle and right figures show the contours of the logistic loss (red) and the squared loss (blue). Here x and y are the two parameters we are fitting. The dot is the optimal point found by BFGS.

[figure: left panel shows the data and the two linear classifiers; middle and right panels show the loss contours with the BFGS optimum marked]

From the contours we can easily see why optimizing squared loss is harder: as Chaconne mentioned, it is non-convex.

Here is one more view from persp3d.

[figure: loss surface rendered with persp3d]


Code

# toy data: two 2-d Gaussian classes from mlbench, labels recoded to 0/1
set.seed(0)
d=mlbench::mlbench.2dnormals(50,2,r=1)
x=d$x
y=ifelse(d$classes==1,1,0)

# negative log likelihood (logistic loss)
lg_loss <- function(w){
  p=plogis(x %*% w)
  L=-y*log(p)-(1-y)*log(1-p)
  return(sum(L))
}
# squared loss on the predicted probabilities
sq_loss <- function(w){
  p=plogis(x %*% w)
  L=sum((y-p)^2)
  return(L)
}

# grid of coefficient values for the loss contours
w_grid_v=seq(-15,15,0.1)
w_grid=expand.grid(w_grid_v,w_grid_v)

# fit each loss with BFGS and evaluate it on the grid
opt1=optimx::optimx(c(1,1),fn=lg_loss ,method="BFGS")
z1=matrix(apply(w_grid,1,lg_loss),ncol=length(w_grid_v))

opt2=optimx::optimx(c(1,1),fn=sq_loss ,method="BFGS")
z2=matrix(apply(w_grid,1,sq_loss),ncol=length(w_grid_v))

par(mfrow=c(1,3))
plot(d,xlim=c(-3,3),ylim=c(-3,3))
abline(0,-opt1$p2/opt1$p1,col='darkred',lwd=2)
abline(0,-opt2$p2/opt2$p1,col='blue',lwd=2)
grid()
contour(w_grid_v,w_grid_v,z1,col='darkred',lwd=2, nlevels = 8)
points(opt1$p1,opt1$p2,col='darkred',pch=19)
grid()
contour(w_grid_v,w_grid_v,z2,col='blue',lwd=2, nlevels = 8)
points(opt2$p1,opt2$p2,col='blue',pch=19)
grid()


# library(rgl)
# persp3d(w_grid_v,w_grid_v,z1,col='darkred')
Haitao Du
I don't see any non-convexity on the third subplot of your first figure...
amoeba says Reinstate Monica
@amoeba I thought a convex contour looks more like an ellipse; two U-shaped curves back to back would be non-convex, is that right?
Haitao Du
No, why? Maybe it's a part of a larger ellipse-like contour? I mean, it might very well be non-convex, I am just saying that I do not see it on this particular figure.
amoeba says Reinstate Monica