上一次的分享我們提到了神經(jīng)網(wǎng)絡(luò)的幾個(gè)基本概念,其中提到了隨機(jī)梯度下降(SGD)算法是神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)(或者更通用的,一般性參數(shù)優(yōu)化問題)的主流方法。概念上,神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)非常簡(jiǎn)單,可以被歸納為下面的步驟:(a) 構(gòu)造神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)(選擇層數(shù)、激活函數(shù)等);(b) 初始化構(gòu)造出的神經(jīng)網(wǎng)絡(luò)參數(shù)w;(c) 對(duì)于給定的訓(xùn)練樣本(x,y)與當(dāng)前的w,計(jì)算梯度▽w
神經(jīng)網(wǎng)絡(luò)的初學(xué)者往往會(huì)發(fā)現(xiàn),上述四個(gè)步驟當(dāng)中,對(duì)于給定樣本(x,y),計(jì)算其梯度是最不直觀的一個(gè)步驟。本文我們玻森(bosonnlp.com)的討論就圍繞解決梯度▽w的核心算法:后向傳播算法來展開。
首先理清一個(gè)概念,步驟(d)的梯度下降算法是一種優(yōu)化算法,而我們要討論的后向傳播算法,是計(jì)算步驟(c)中所需要梯度▽w的一種算法。下面的討論,我們首先完成單參數(shù)(即只有一個(gè)參數(shù)w∈R需要學(xué)習(xí))的特例情況下的推導(dǎo),進(jìn)而通過動(dòng)態(tài)規(guī)劃(Dynamic programming)思想,將其推導(dǎo)泛化到多變量的情況。需要注意的是,雖然后向傳播概念上并不復(fù)雜,所用到的數(shù)學(xué)工具也很基本,但由于所涉及的變量較多、分層等特點(diǎn),在推導(dǎo)的時(shí)候需要比較仔細(xì),類似繡花。
特例
在討論后向傳播算法之前,我們簡(jiǎn)單回顧一下單變量微積分中的求導(dǎo)規(guī)則。來看個(gè)例子,假設(shè)我們有一個(gè)極端簡(jiǎn)化的網(wǎng)絡(luò),其中只有一個(gè)需要學(xué)習(xí)的參數(shù)w,形式如下
并且假設(shè)損失函數(shù)Cost為平方誤差(MSE)。
假設(shè)我們只有一個(gè)訓(xùn)練樣本(x,y)=(1,1)。因?yàn)檫@個(gè)形式非常簡(jiǎn)單,我們?cè)囋噷⒃摌颖局苯訋霌p失函數(shù):
顯然當(dāng)w=0時(shí),我們可以讓損失函數(shù)為0,達(dá)到最優(yōu)。下面讓我們假裝不知道最優(yōu)解,考慮如何用梯度下降方法來求解。假設(shè)我們猜w0=2為最優(yōu),帶入計(jì)算得到
嗯,不算太壞的一個(gè)初始值。讓我們計(jì)算其梯度,或者損失函數(shù)關(guān)于w的導(dǎo)數(shù)。
設(shè)置學(xué)習(xí)率參數(shù)η=0.02,我們可以通過梯度下降方法來不斷改進(jìn)w,以達(dá)到降低損失函數(shù)的目的。三十個(gè)迭代的損失函數(shù)變化如下:
生成上圖采用的是如下Python代碼
import matplotlib.pyplot as pltw0, eta, n_iter = 2, 0.02, 30gradient_w = lambda w: 2*(w**3)cost = lambda w: 0.5*(w**4)costs = []w = w0for i in range(n_iter): costs.append(cost(w)) w = w - eta*gradient_w(w)# gradientdescentplt.plot(range(n_iter), costs)
可以發(fā)現(xiàn),經(jīng)過30次迭代后,我們的參數(shù)w從初始的2改進(jìn)到了0.597,正在接近我們的最優(yōu)目標(biāo)w?=0。
對(duì)于一般f(w)的情況
回憶一下,上面的結(jié)果是基于我們給定 y=w2x+1.下得到的,注意這里我們假設(shè)輸入信號(hào)x為常量。我們將上面的求解步驟做一點(diǎn)點(diǎn)泛化。
重復(fù)上面的求解
關(guān)于w求導(dǎo),
注意,上面求導(dǎo)用到了鏈?zhǔn)椒▌t(Chain Rule),即
或者寫成偏導(dǎo)數(shù)形式:
對(duì)于一般性損失函數(shù)的情況
上式推導(dǎo)基于損失函數(shù)為平方最小下得出,那么我們?cè)俜夯稽c(diǎn),對(duì)于任意給定的可導(dǎo)損失函數(shù),其關(guān)于w的梯度:
其中▽Cost(f(w),ˉy)是損失函數(shù)關(guān)于f(w)的導(dǎo)數(shù)。實(shí)際上這個(gè)形式很通用,對(duì)于任意特定的損失函數(shù)和神經(jīng)網(wǎng)絡(luò)的激活函數(shù),都可以通過這個(gè)式子進(jìn)行梯度計(jì)算。譬如,對(duì)于一個(gè)有三層的神經(jīng)網(wǎng)絡(luò)
同樣通過鏈?zhǔn)椒▌t,
上式看上去比較復(fù)雜,我們可以在符號(hào)上做一點(diǎn)簡(jiǎn)化。令每一層網(wǎng)絡(luò)得到的激活函數(shù)結(jié)果為ai,即a1=f1(w),a2=f2(f1(w)), 那么:
即:不論復(fù)合函數(shù)f本身有多么復(fù)雜,我們都可以將其導(dǎo)數(shù)拆解成每一層函數(shù)fi的導(dǎo)數(shù)的乘積。
上面的推導(dǎo)我們給出了當(dāng)神經(jīng)網(wǎng)絡(luò)僅僅被一個(gè)可學(xué)習(xí)參數(shù)w所刻畫的情況。一句話總結(jié),在單參數(shù)的網(wǎng)絡(luò)推導(dǎo)中,我們真正用到的唯一數(shù)學(xué)工具就是鏈?zhǔn)椒▌t。實(shí)際問題中,我們面對(duì)的參數(shù)往往是數(shù)以百萬計(jì)的,這也就是為什么我們無法采用直覺去“猜”到最優(yōu)值,而需要用梯度下降方法的原因。下面我考慮在多參數(shù)情況下,如何求解梯度。
首先,不是一般性的,我們假設(shè)所構(gòu)建的為一個(gè)L層的神經(jīng)網(wǎng)絡(luò),其中每一層神經(jīng)網(wǎng)絡(luò)都經(jīng)過線性變換和非線性變換兩個(gè)步驟(為簡(jiǎn)化推導(dǎo),這里我們略去對(duì)bias項(xiàng)的考慮):
定義網(wǎng)絡(luò)的輸入a0=x∈Rn,而aL作為輸出層。一般的,我們令網(wǎng)絡(luò)第l層具有nl個(gè)節(jié)點(diǎn),那么al,zl∈Rnl,Wl∈Rnl×nl?1。注意此時(shí)我們網(wǎng)絡(luò)共有N=n0n1+?+nL?1nL個(gè)參數(shù)需要優(yōu)化。
為了求得梯度▽w,我們關(guān)心參數(shù)Wlji關(guān)于損失函數(shù)的的導(dǎo)數(shù):?Cost?Wlji,但似乎難以把Wlji簡(jiǎn)單地與損失函數(shù)聯(lián)系起來。問題在哪里呢?事實(shí)上,在單參數(shù)的情況下,我們通過鏈?zhǔn)椒▌t,成功建立第一層網(wǎng)絡(luò)的w參數(shù)與最終損失函數(shù)的聯(lián)系。譬如,f(w)=f2(f1(w)),w的改變影響f1函數(shù)的值,而連鎖反應(yīng)影響到f2的函數(shù)結(jié)果。那么,對(duì)于Wlji值的改變,會(huì)影響zlj,從而影響alj。通過Wl+1的線性變換(因?yàn)?span style="color: inherit;">zl+1=Wl+1al),alj的改變將會(huì)影響到每一個(gè)zl+1k(1≤k≤nl+1)。
將上面的思路寫下來:
可以通過上式不斷展開進(jìn)行其梯度計(jì)算。這個(gè)方式相當(dāng)于我們枚舉了每一條Wlji改變對(duì)最終損失函數(shù)影響的路徑。通過簡(jiǎn)單使用鏈?zhǔn)椒▌t,我們得到了一個(gè)指數(shù)級(jí)復(fù)雜度的梯度計(jì)算方法。稍仔細(xì)觀察可以發(fā)現(xiàn),這個(gè)是一個(gè)典型的遞歸結(jié)構(gòu)(為什么呢?因?yàn)?span style="color: inherit;">zl=Wlal?1定義的是一個(gè)遞歸結(jié)構(gòu)),可以采用動(dòng)態(tài)規(guī)劃(Dynamic programming)方法,通過記錄子問題答案而進(jìn)行快速求解。設(shè)δli=?Cost?zli用于動(dòng)態(tài)規(guī)劃的狀態(tài)記錄。我們先解決最后一層的邊界情況:
上式為通用形式。對(duì)于Sigmoid, Tanh等形式的element-wise激活函數(shù),因?yàn)榭梢詫懗?span style="color: inherit;">aLj=fL(zLj)的形式,所示上式可以簡(jiǎn)化為:
即該情況下,最后一層的關(guān)于zLi的導(dǎo)數(shù)與損失函數(shù)在aLi導(dǎo)數(shù)和最后一層激活函數(shù)在zLi的導(dǎo)數(shù)相關(guān)。注意當(dāng)選擇了具體的損失函數(shù)和每層的激活函數(shù)后,▽Cost與f′i也被唯一確定了。下面我們看看動(dòng)態(tài)規(guī)劃的狀態(tài)轉(zhuǎn)移情況:
成功建立δl與δl+1的遞推關(guān)系,所以整個(gè)網(wǎng)絡(luò)的δli可以被計(jì)算出。在確定了δli后,我們的對(duì)于任意參數(shù)Wlji的導(dǎo)數(shù)可以被簡(jiǎn)單表示出:
至此,我們通過鏈?zhǔn)椒▌t和動(dòng)態(tài)規(guī)劃的思想,不失一般性的得到了后向傳播算法的推導(dǎo)。
1. 后向傳播算法的時(shí)間復(fù)雜度是多少?
不難看出,為了進(jìn)行后向傳播,我們首先需要計(jì)算每一層的激活函數(shù),即al(1≤l≤L),這一步與后向傳播相對(duì),通常稱為前向傳播,復(fù)雜度為O(N),與網(wǎng)絡(luò)中參數(shù)的個(gè)數(shù)相當(dāng)。而后向傳播的步驟,通過我們的狀態(tài)轉(zhuǎn)移的推導(dǎo),也可以看出其復(fù)雜度為O(N),所以總的時(shí)間復(fù)雜度為O(N)。需要注意的是,采用mini-batch的方式優(yōu)化時(shí),我們會(huì)將b個(gè)樣本打包進(jìn)行計(jì)算。這本質(zhì)上將后向傳播的矩陣-向量乘積變成了矩陣-矩陣乘積。對(duì)于任意兩個(gè)n×n的矩陣的乘法,目前理論最優(yōu)復(fù)雜度為O(n2.3728)的類Coppersmith–Winograd算法。這類算法由于常數(shù)巨大,不能很好利用GPU并行等限制,并沒有在真正在機(jī)器學(xué)習(xí)或數(shù)值計(jì)算領(lǐng)域有應(yīng)用。尋求智力挑戰(zhàn)的朋友可閱讀Powers of tensors and fast matrix multiplication。
2. 對(duì)于后向傳播的學(xué)習(xí)算法,生物上是否有類似的機(jī)制?
這是一個(gè)有爭(zhēng)議的問題。Hinton教授在其How to do backpropagation in a brain演講當(dāng)中,講到了人們對(duì)于后向傳播不能在生物學(xué)上實(shí)現(xiàn)的三個(gè)原因:a) 神經(jīng)元之間不傳播實(shí)數(shù)信號(hào),而通過尖峰信號(hào)(spikes)溝通。Hinton解釋說通過Poisson過程,意味著可以傳遞實(shí)數(shù)信號(hào),并且采用spike而不是實(shí)數(shù)進(jìn)行信號(hào)傳遞是一個(gè)更魯邦的過程。b) 神經(jīng)元不會(huì)求導(dǎo)(f′l(?))。Hinton說通過構(gòu)建filter可以實(shí)現(xiàn)(數(shù)值)求導(dǎo)過程。c) 神經(jīng)元不是對(duì)稱的?或者說神經(jīng)元連接不是一個(gè)無向圖而是有向圖。這意味著通過前向傳播Wlal?1與后向傳播的Wl并不應(yīng)該是一個(gè)W。Hinton教授解釋說,通過一些數(shù)值實(shí)驗(yàn)發(fā)現(xiàn),其實(shí)即便前向后向的W不對(duì)稱(比如讓后向傳播的W為固定的隨機(jī)矩陣),采用類似的梯度算法也可以收斂到不錯(cuò)的解。我不同意Hinton教授的觀點(diǎn)。其解釋在邏輯上混淆一個(gè)基本常識(shí):大腦可以做到并不意味著大腦事實(shí)是這樣完成運(yùn)算的。其實(shí)我們已經(jīng)看到,通過與非門也可以完成所有的函數(shù)運(yùn)算,但這并不代表我們大腦里面一定裝載了10億個(gè)與非門。而有大量證據(jù)表明(如能耗,小樣本學(xué)習(xí)),后向傳播算法與真實(shí)大腦學(xué)習(xí)的機(jī)制相去甚遠(yuǎn)。所以我覺得更合理的對(duì)待,仍然是將后向傳播作為一種高效計(jì)算嵌套函數(shù)梯度的數(shù)值算法工具,而不必強(qiáng)行將其附會(huì)成大腦的工作原理。
下一次分享,我們主要從優(yōu)化算法的正則化(Regularization)的角度進(jìn)行討論,歡迎關(guān)注~
聯(lián)系客服