Goni fungsi logistik

15

Saya mengalami kesulitan untuk mendapatkan Hessian dari fungsi objektif, l(θ) , dalam regresi logistik di mana adalah: l(θ)

l(θ)=i=1m[yilog(hθ(xi))+(1yi)log(1hθ(xi))]

hθ(x) adalah fungsi logistik. Hessian adalah . Saya mencoba menurunkannya dengan menghitung , tetapi kemudian tidak jelas bagi saya bagaimana cara mendapatkan notasi matriks dari .XTDX2l(θ)θiθj2l(θ)θiθj

Adakah yang tahu cara bersih dan mudah untuk mendapatkan ?XTDX

DSKim
sumber
3
apa yang kau bisa untuk ? 2lθiθj
Glen_b -Reinstate Monica
1
Berikut ini adalah kumpulan slide yang menunjukkan perhitungan tepat yang Anda cari: sites.stat.psu.edu/~jiali/course/stat597e/notes2/logit.pdf
Saya menemukan video yang luar biasa yang menghitung langkah Hessian demi langkah. Regresi logistik (biner) - menghitung Hessian
Naomi

Jawaban:

19

Di sini saya mendapatkan semua sifat dan identitas yang diperlukan agar solusi dapat berdiri sendiri, tetapi selain itu derivasi ini bersih dan mudah. Mari kita meresmikan notasi kita dan menulis fungsi kerugian sedikit lebih kompak. Pertimbangkan m sampel {xi,yi} sehingga xiRd dan yiR . Ingatlah bahwa dalam regresi logistik biner kita biasanya memiliki fungsi hipotesis hθ menjadi fungsi logistik. Secara formal

hθ(xi)=σ(ωTxi)=σ(zi)=11+ezi,

dimana ωRd dan zi=ωTxi . Fungsi kerugian (yang saya percaya OP kehilangan tanda negatif) kemudian didefinisikan sebagai:

l(ω)=i=1m(yilogσ(zi)+(1yi)log(1σ(zi)))

Ada dua sifat penting dari fungsi logistik yang saya peroleh di sini untuk referensi di masa mendatang. Pertama, perhatikan bahwa 1σ(z)=11/(1+ez)=ez/(1+ez)=1/(1+ez)=σ(z) .

Perhatikan juga itu

zσ(z)=z(1+ez)1=ez(1+ez)2=11+ezez1+ez=σ(z)(1σ(z))

Alih-alih mengambil turunan sehubungan dengan komponen, di sini kami akan bekerja langsung dengan vektor (Anda dapat meninjau turunan dengan vektor di sini ). Hessian dari fungsi kerugian l(ω) diberikan oleh 2l(ω) , tetapi pertama-tama ingat bahwazω=xTωω=xTdanzωT=ωTxωT=x .

Mari li(ω)=yilogσ(zi)(1yi)log(1σ(zi)) . Menggunakan properti yang kami peroleh di atas dan aturan rantai

logσ(zi)ωT=1σ(zi)σ(zi)ωT=1σ(zi)σ(zi)ziziωT=(1σ(zi))xilog(1σ(zi))ωT=11σ(zi)(1σ(zi))ωT=σ(zi)xi

Sekarang sepele untuk menunjukkan itu

li(ω)=li(ω)ωT=yixi(1σ(zi))+(1yi)xiσ(zi)=xi(σ(zi)yi)

Wah!

Langkah terakhir kami adalah menghitung Hessian

2li(ω)=li(ω)ωωT=xixiTσ(zi)(1σ(zi))

For m samples we have 2l(ω)=i=1mxixiTσ(zi)(1σ(zi)). This is equivalent to concatenating column vectors xiRd into a matrix X of size d×m such that i=1mxixiT=XXT. The scalar terms are combined in a diagonal matrix D such that Dii=σ(zi)(1σ(zi)). Finally, we conclude that

H(ω)=2l(ω)=XDXT

A faster approach can be derived by considering all samples at once from the beginning and instead work with matrix derivatives. As an extra note, with this formulation it's trivial to show that l(ω) is convex. Let δ be any vector such that δRd. Then

δTH(ω)δ=δT2l(ω)δ=δTXDXTδ=δTXD(δTX)T=δTDX20

since D>0 and δTX0. This implies H is positive-semidefinite and therefore l is convex (but not strongly convex).

Manuel Morales
sumber
2
In the last equation, shouldn't it be ||δD1/2X|| since XDX = XD1/2(XD1/2)?
appletree
1
Shouldn't it be XTDX?
Chintan Shah