Was ist die GELU-Aktivierung?


17

Ich habe BERT-Artikel durchgearbeitet, in dem GELU (Gaußsche Fehler- Lineareinheit ) verwendet wird, wobei die Gleichung wie folgt lautet:

GELU(x)=xP(Xx)=xΦ(x).
das entspricht Könnten Sie die Gleichung vereinfachen und erklären, wie sie angenommen wurde.
0.5x(1+tanh[2/π(x+0.044715x3)])

Antworten:


16

GELU-Funktion

Wir können die kumulative Verteilung vonN(0,1) , Φ(x)GELU(x):=x P (Xx)=xΦ(x)=0,5x ( 1 + erf ( x , dh wie folgt erweitern: Φ(x)

GELU(x):=xP(Xx)=xΦ(x)=0.5x(1+erf(x2))

Beachten Sie, dass dies eine Definition ist , keine Gleichung (oder eine Beziehung). Autoren haben einige Begründungen für diesen Vorschlag vorgesehen, eine stochastische zB Analogie , aber mathematisch, das ist nur eine Definition.

Hier ist die Handlung von GELU:

Tanh-Annäherung

Für diese Art von numerischen Approximationen besteht die Schlüsselidee darin, eine ähnliche Funktion zu finden (hauptsächlich basierend auf Erfahrung), sie zu parametrisieren und dann an eine Reihe von Punkten aus der ursprünglichen Funktion anzupassen.

Zu wissen, dass sehr nahe anerf(x)tanh(x)

und die erste Ableitung von stimmt mit der von bei überein , was , passen wir (oder mit mehr Begriffen) zu einer Menge von Punkten .erf(x2)tanh(2πx)x=02π

tanh(2π(x+ax2+bx3+cx4+dx5))
(xi,erf(xi2))

Ich habe diese Funktion an 20 Samples zwischen angepasst (unter Verwendung dieser Site ), und hier sind die Koeffizienten:(1.5,1.5)

Durch Setzen von wurde auf geschätzt . Bei mehr Proben aus einem größeren Bereich (an dieser Stelle sind nur 20 zulässig) liegt der Koeffizient näher bei des . Endlich bekommen wira=c=d=0b0.04495641b0.044715

GELU(x)=xΦ(x)=0.5x(1+erf(x2))0.5x(1+tanh(2π(x+0.044715x3)))

mit dem mittleren Fehlerquadrat für .108x[10,10]

Beachten Sie, dass der Begriff in den folgenden Parametern enthalten gewesen wäre , wenn wir die Beziehung zwischen den ersten Ableitungen nicht verwendet hätten: was weniger schön ist (weniger analytisch, mehr numerisch)!2π

0.5x(1+tanh(0.797885x+0.035677x3))

Die Parität ausnutzen

Wie von @BookYourLuck vorgeschlagen , können wir die Parität der Funktionen verwenden, um den Raum der Polynome, in denen wir suchen, einzuschränken. Das heißt, da eine ungerade Funktion ist, dh , und auch eine ungerade Funktion ist, befindet sich die Polynomfunktion Inneren sollte auch ungerade sein (sollte nur ungerade Potenzen von ), um erff(x)=f(x)tanhpol(x)tanhx

erf(x)tanh(pol(x))=tanh(pol(x))=tanh(pol(x))erf(x)

Früher haben wir hatten das Glück , mit (fast) am Ende Null - Koeffizienten für gerade Potenzen und , aber im Allgemeinen, könnte dies zu geringer Qualität Annäherungen führt , dass zum Beispiel hat einen Begriff wie , dass wird durch zusätzliche Ausdrücke (gerade oder ungerade) aufgehoben, anstatt einfach wählen .x2x40.23x20x2

Sigmoid-Approximation

erf(x)2(σ(x)12)104x[10,10]

Hier ist ein Python-Code zum Generieren von Datenpunkten, Anpassen der Funktionen und Berechnen der mittleren quadratischen Fehler:

import math
import numpy as np
import scipy.optimize as optimize


def tahn(xs, a):
    return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]


def sigmoid(xs, a):
    return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]


print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
#       .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])

# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))

sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))

# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()

# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()

print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)

if print_points == 1:
    print(len(xs))
    for x, erf in zip(xs, erfs):
        print(x, erf)

Ausgabe:

Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05

2
Warum ist die Annäherung erforderlich? Könnten sie nicht einfach die erf-Funktion verwenden?
SebiSebi

8

Φ(x)=12erfc(x2)=12(1+erf(x2))
erf
erf(x2)tanh(2π(x+ax3))
a0.044715

x[1,1]x

tanh(x)=xx33+o(x3)
erf(x)=2π(xx33)+o(x3).
tanh(2π(x+ax3))=2π(x+(a23π)x3)+o(x3)
erf(x2)=2π(xx36)+o(x3).
x3
a0.04553992412
0.044715

Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.