Hessisch logistische Funktion


15

Ich habe Schwierigkeiten die Hessian der Zielfunktion, abzuleiten l ( θ )l(θ) , in logistischer Regression , wo L ( θ )l(θ) ist: L ( θ ) = m Σ i = 1 [ y i log ( h θ ( x i ) ) + ( 1 - y i ) log ( 1 - h & thgr; ( x i ) ) ]

l(θ)=i=1m[yilog(hθ(xi))+(1yi)log(1hθ(xi))]

h θ ( x )hθ(x) ist eine logistische Funktion. Die Hessian ist X T D XXTDX . Ich habe versucht, es durch Berechnung von2 l ( θ ) abzuleitenθ iθ j2l(θ)θiθj , aber dann war mir nicht klar, wie ich von2l(θ)zur Matrixnotation komme& thgr; i& thgr; j2l(θ)θiθj .

Kennt jemand eine saubere und einfache Möglichkeit, X T D X abzuleiten XTDX?


3
was hast du dafür bekommen ?2lθ iθ j ? 2lθiθj
Glen_b

1
Hier ist ein guter Satz Folien, die die genaue Berechnung zeigen, die Sie suchen: sites.stat.psu.edu/~jiali/course/stat597e/notes2/logit.pdf

Ich habe ein wunderbares Video gefunden, das den Hessischen Schritt für Schritt berechnet. Logistische Regression (binär) - Berechnung des Hessischen
Naomi

Antworten:


19

Hier leite ich alle notwendigen Eigenschaften und Identitäten ab, damit die Lösung in sich geschlossen ist, aber ansonsten ist diese Herleitung sauber und einfach. Lassen Sie uns unsere Notation formalisieren und die Verlustfunktion etwas kompakter schreiben. Betrachten mm Proben { x i , y i }{xi,yi} , so dass x iR dxiRd und y iRyiR . Denken Sie daran, dass in der binären logistischen Regression typischerweise die Hypothesenfunktion h θhθ die logistische Funktion ist. Formal

hθ(xi)=σ(ωTxi)=σ(zi)=11+ezi,

hθ(xi)=σ(ωTxi)=σ(zi)=11+ezi,

where ωRdωRd and zi=ωTxizi=ωTxi. The loss function (which I believe OP's is missing a negative sign) is then defined as:

l(ω)=mi=1(yilogσ(zi)+(1yi)log(1σ(zi)))

l(ω)=i=1m(yilogσ(zi)+(1yi)log(1σ(zi)))

There are two important properties of the logistic function which I derive here for future reference. First, note that 1σ(z)=11/(1+ez)=ez/(1+ez)=1/(1+ez)=σ(z)1σ(z)=11/(1+ez)=ez/(1+ez)=1/(1+ez)=σ(z).

Also note that

zσ(z)=z(1+ez)1=ez(1+ez)2=11+ezez1+ez=σ(z)(1σ(z))

zσ(z)=z(1+ez)1=ez(1+ez)2=11+ezez1+ez=σ(z)(1σ(z))

Instead of taking derivatives with respect to components, here we will work directly with vectors (you can review derivatives with vectors here). The Hessian of the loss function l(ω)l(ω) is given by 2l(ω)⃗ 2l(ω), but first recall that zω=xTωω=xTzω=xTωω=xT and zωT=ωTxωT=xzωT=ωTxωT=x.

Let li(ω)=yilogσ(zi)(1yi)log(1σ(zi))li(ω)=yilogσ(zi)(1yi)log(1σ(zi)). Using the properties we derived above and the chain rule

logσ(zi)ωT=1σ(zi)σ(zi)ωT=1σ(zi)σ(zi)ziziωT=(1σ(zi))xilog(1σ(zi))ωT=11σ(zi)(1σ(zi))ωT=σ(zi)xi

logσ(zi)ωTlog(1σ(zi))ωT=1σ(zi)σ(zi)ωT=1σ(zi)σ(zi)ziziωT=(1σ(zi))xi=11σ(zi)(1σ(zi))ωT=σ(zi)xi

It's now trivial to show that

li(ω)=li(ω)ωT=yixi(1σ(zi))+(1yi)xiσ(zi)=xi(σ(zi)yi)

⃗ li(ω)=li(ω)ωT=yixi(1σ(zi))+(1yi)xiσ(zi)=xi(σ(zi)yi)

whew!

Our last step is to compute the Hessian

2li(ω)=li(ω)ωωT=xixTiσ(zi)(1σ(zi))

⃗ 2li(ω)=li(ω)ωωT=xixTiσ(zi)(1σ(zi))

For mm samples we have 2l(ω)=mi=1xixTiσ(zi)(1σ(zi))⃗ 2l(ω)=mi=1xixTiσ(zi)(1σ(zi)). This is equivalent to concatenating column vectors xiRdxiRd into a matrix XX of size d×md×m such that mi=1xixTi=XXTmi=1xixTi=XXT. The scalar terms are combined in a diagonal matrix DD such that Dii=σ(zi)(1σ(zi))Dii=σ(zi)(1σ(zi)). Finally, we conclude that

H(ω)=2l(ω)=XDXT

H⃗ (ω)=⃗ 2l(ω)=XDXT

A faster approach can be derived by considering all samples at once from the beginning and instead work with matrix derivatives. As an extra note, with this formulation it's trivial to show that l(ω)l(ω) is convex. Let δδ be any vector such that δRdδRd. Then

δTH(ω)δ=δT2l(ω)δ=δTXDXTδ=δTXD(δTX)T=δTDX20

δTH⃗ (ω)δ=δT⃗ 2l(ω)δ=δTXDXTδ=δTXD(δTX)T=δTDX20

since D>0D>0 and δTX0δTX0. This implies HH is positive-semidefinite and therefore ll is convex (but not strongly convex).


2
In the last equation, shouldn't it be ||δD1/2X||||δD1/2X|| since XDXXDX = XD1/2(XD1/2)XD1/2(XD1/2)?
appletree

1
Shouldn't it be XTDX?
Chintan Shah
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.