Saya menggunakan kertas BERT yang menggunakan GELU (Gaussian Error Linear Unit) yang menyatakan persamaan sebagai
yang
sumber
Saya menggunakan kertas BERT yang menggunakan GELU (Gaussian Error Linear Unit) yang menyatakan persamaan sebagai
yang
Kita dapat memperluas distribusi kumulatif dari , yaitu , sebagai berikut:
Perhatikan bahwa ini adalah definisi , bukan persamaan (atau hubungan). Penulis telah memberikan beberapa pembenaran untuk proposal ini, misalnya analogi stokastik , namun secara matematis, ini hanyalah sebuah definisi.
Berikut ini alur cerita GELU:
Untuk jenis perkiraan numerik ini, ide kuncinya adalah menemukan fungsi yang serupa (terutama berdasarkan pengalaman), membuat parameterisasi, dan kemudian memasangnya ke satu set poin dari fungsi aslinya.
Mengetahui bahwa sangat dekat dengan
dan turunan pertama dari bertepatan dengan di , yang merupakan , kami melanjutkan untuk mencocokkan
(atau dengan lebih banyak istilah) ke sekumpulan titik .
Saya telah memasang fungsi ini ke 20 sampel antara ( menggunakan situs ini ), dan berikut adalah koefisiennya:
Dengan mengatur , diperkirakan 0,04495641 . Dengan lebih banyak sampel dari rentang yang lebih luas (situs itu hanya diizinkan 20), koefisien b akan lebih dekat dengan kertas . Akhirnya kita dapatkan0,044715
dengan mean squared error untuk .
Perhatikan bahwa jika kami tidak memanfaatkan hubungan antara derivatif pertama, istilah akan dimasukkan dalam parameter sebagai berikut
Berikut adalah kode Python untuk menghasilkan titik data, menyesuaikan fungsi, dan menghitung rata-rata kesalahan kuadrat:
import math
import numpy as np
import scipy.optimize as optimize
def tahn(xs, a):
return [math.tanh(math.sqrt(2 / math.pi) * (x + a * x**3)) for x in xs]
def sigmoid(xs, a):
return [2 * (1 / (1 + math.exp(-a * x)) - 0.5) for x in xs]
print_points = 0
np.random.seed(123)
# xs = [-2, -1, -.9, -.7, 0.6, -.5, -.4, -.3, -0.2, -.1, 0,
# .1, 0.2, .3, .4, .5, 0.6, .7, .9, 2]
# xs = np.concatenate((np.arange(-1, 1, 0.2), np.arange(-4, 4, 0.8)))
# xs = np.concatenate((np.arange(-2, 2, 0.5), np.arange(-8, 8, 1.6)))
xs = np.arange(-10, 10, 0.001)
erfs = np.array([math.erf(x/math.sqrt(2)) for x in xs])
ys = np.array([0.5 * x * (1 + math.erf(x/math.sqrt(2))) for x in xs])
# Fit tanh and sigmoid curves to erf points
tanh_popt, _ = optimize.curve_fit(tahn, xs, erfs)
print('Tanh fit: a=%5.5f' % tuple(tanh_popt))
sig_popt, _ = optimize.curve_fit(sigmoid, xs, erfs)
print('Sigmoid fit: a=%5.5f' % tuple(sig_popt))
# curves used in https://mycurvefit.com:
# 1. sinh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))/cosh(sqrt(2/3.141593)*(x+a*x^2+b*x^3+c*x^4+d*x^5))
# 2. sinh(sqrt(2/3.141593)*(x+b*x^3))/cosh(sqrt(2/3.141593)*(x+b*x^3))
y_paper_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + 0.044715 * x**3))) for x in xs])
tanh_error_paper = (np.square(ys - y_paper_tanh)).mean()
y_alt_tanh = np.array([0.5 * x * (1 + math.tanh(math.sqrt(2/math.pi)*(x + tanh_popt[0] * x**3))) for x in xs])
tanh_error_alt = (np.square(ys - y_alt_tanh)).mean()
# curve used in https://mycurvefit.com:
# 1. 2*(1/(1+2.718281828459^(-(a*x))) - 0.5)
y_paper_sigmoid = np.array([x * (1 / (1 + math.exp(-1.702 * x))) for x in xs])
sigmoid_error_paper = (np.square(ys - y_paper_sigmoid)).mean()
y_alt_sigmoid = np.array([x * (1 / (1 + math.exp(-sig_popt[0] * x))) for x in xs])
sigmoid_error_alt = (np.square(ys - y_alt_sigmoid)).mean()
print('Paper tanh error:', tanh_error_paper)
print('Alternative tanh error:', tanh_error_alt)
print('Paper sigmoid error:', sigmoid_error_paper)
print('Alternative sigmoid error:', sigmoid_error_alt)
if print_points == 1:
print(len(xs))
for x, erf in zip(xs, erfs):
print(x, erf)
Keluaran:
Tanh fit: a=0.04485
Sigmoid fit: a=1.70099
Paper tanh error: 2.4329173471294176e-08
Alternative tanh error: 2.698034519269613e-08
Paper sigmoid error: 5.6479106346814546e-05
Alternative sigmoid error: 5.704246564663601e-05
sumber