該能力簡單來說就是,對于一個預訓練好的大語言模型,遷移到新任務上的時候,只需要給模型輸入幾個示例(示例輸入和示例輸出對),模型就能為新輸入生成正確輸出而不需要對模型做 fine-tuning。
這也引發(fā)了研究人員對該能力產(chǎn)生原因的思考和探索。本文會首先給讀者介紹什么是上下文學習,接著解讀一篇最近由微軟研究院發(fā)布的探索 LLM 上下文學習能力來源的文章[7]
。
GPT-n
系列的模型都屬于自回歸類的語言模型,所謂自回歸模型就是根據(jù)當前輸入預測下一個詞,然后將預測結(jié)果和輸入拼接再當做模型的輸入再預測下一個詞,這樣循環(huán)往復。
而自回歸模型的訓練目標也很簡單,就是從超大規(guī)模語料庫中采樣訓練樣本,模型根據(jù)輸入輸出一個概率向量(概率向量包含所有詞的預測概率,對于GPT-3 模型來說,維度約1千多萬),而因為文本數(shù)據(jù)自帶標注所以我們是知道真實的下一個詞,所以損失函數(shù)就采用得交叉熵。
然后研究人員發(fā)現(xiàn),預訓練好的 GPT-3 模型擁有一項神奇的能力,后來被稱為:上下文學習(In-Context Learning)。
這項能力簡單來說就是,預訓練好的 GPT-3 模型在遷移到新任務上的時候并不需要重新訓練,而只需要提供任務描述(這個任務描述是可選項)接著提供幾個示例(任務查詢和對應答案,以一對對的形式組織),最后加上要模型回答的查詢。將以上內(nèi)容打包一起作為模型的輸入,則模型就能正確輸出最后一個查詢對應的答案。
舉個例子:
比如現(xiàn)在想用 GPT-3 來做個翻譯任務,翻譯英文為法文。輸入的格式如下:
首先第一行是對任務描述,告訴模型要做翻譯,接下來三行就是示例,英文單詞和對應的法文單詞對,最后一行就是待翻譯的英文單詞。將以上內(nèi)容整體作為 GPT-3 的輸入,讓模型去補全輸出就能得到 cheese
對應的法文單詞。
上下文學習非常的靈活,除了上面展示的翻譯任務,還可以做語法修飾甚至寫代碼。而神奇的地方就在于,在 GPT-3 的訓練過程中是并沒有顯式的提供,類似測試階段任務描述加示例這樣的訓練數(shù)據(jù)。
當然 GPT-3 的訓練數(shù)據(jù)量非常巨大(比如包含了 wiki, 書本期刊,reddit 上的討論等等),或許里面就已經(jīng)就包含了各種任務類似結(jié)構(gòu)的數(shù)據(jù),GPT-3 模型容量足夠大能夠?qū)⑺杏柧殧?shù)據(jù)都記了下來。
對于上下文學習能力的成因,目前還是一個開放性的問題。為什么只有大規(guī)模的語言模型才會具備該能力?或許只有模型參數(shù)量大還不夠,還必須要訓練數(shù)據(jù)量也足夠大,模型才能顯現(xiàn)出該能力?
首先來看一個很簡單的任務,就是讓模型直接復制輸入的內(nèi)容。
首先示例個數(shù)設置為 5
個,每個示例輸入包含 5
個不同的小寫單詞(從字母表前 8 個小寫字母中隨機選5個得到),這些單詞用逗號分隔,輸出直接拷貝的輸入,比如:
Input: g, c, b, h, d
Output: g, c, b, h, d
Input: b, g, d, h, a
Output: b, g, d, h, a
Input: f, c, d, e, h
Output: f, c, d, e, h
Input: c, f, g, h, d
Output: c, f, g, h, d
Input: e, f, b, g, d
Output: e, f, b, g, d
Input: a, b, c, d, e
Output:
期待模型的輸出是:
a, b, c, d, e
接著對于5個字母順序的所有可能情況 (8!/3!=6720
,從8個樣本中選5個總的組合數(shù))也就是最后 input
的位置將 6720
個情況都測試了,GPT-3 模型的準確率是 100%
。
接著用 GPT-3 系列最小的模型 text-ada-001
來做這個任務,獲得了 6705/6720 = 99.78%
的準確率,一定程度上證明了模型規(guī)模的重要性。
接著來看 GPT-3 在更復雜一些的任務上的表現(xiàn)。
這個任務是對日期做格式化,將 年-月-日
格式的輸入格式轉(zhuǎn)化成 !月!日!年!
,其中年份四位數(shù),月份和日子是兩位數(shù),比如:
上面這個例子中,示例個數(shù)是3,最后是待測試的日期 2005-07-23
。
為什么選擇日期格式化這個任務呢?
首先足夠簡單,日期包含三個隨機變量(年月日),它們長度都是固定的,而且設定的輸出格式也不是正常的格式,所以訓練數(shù)據(jù)中不太可能包含類似的樣本,也排除了模型可能只是將訓練數(shù)據(jù)都記憶了下來。
接下來看看測試結(jié)果,我們測試了 GPT-3 全系列的模型 [8]
,包括text-ada-001
,text-babbage-001
,text-curie-001
和 text-davinci-003
,模型參數(shù)量依次從小到大排列。
并通過設置不同的上下文示例個數(shù)(對于每個示例個數(shù)的設置,都有2000個測試樣本),記錄各個模型的預測準確率,測試結(jié)果如下:
從圖表展示的結(jié)果來看,固定橫坐標示例個數(shù),則模型越大準確率也越高,模型越大準確率曲線也就更加的陡峭。而對于每個模型來說,增加上下文的示例個數(shù)也能有效提升準確率。
不過仔細觀察圖表可以發(fā)現(xiàn),即使增大示例個數(shù)和模型,模型的精確度也只是無限接近 100%
但還是達不到。
接下來我們分析一下,GPT-3 預測錯誤的樣本都包含哪些類型。
這里我們選取了前10個最常見的錯誤類型,其中圖標中的 DD
表示兩位數(shù)的日子,MM
表示兩位數(shù)字的月份,mm
一位數(shù)的月份,YYYY
則是四位數(shù)的年份,YY
是兩位數(shù)的年份,**
則是其他的兩位數(shù)。
從實驗結(jié)果上看,隨著上下文示例個數(shù)的增加,預測錯誤的樣本個數(shù)也在下降。
而模型預測錯誤最多的格式是,將日期放在月份前面,這也能理解,因為訓練數(shù)據(jù)中常見的日期格式都是先日期,再月份,最后年份。
繼續(xù)分析模型預測錯誤的樣本,發(fā)現(xiàn)一個有趣的結(jié)果:
就是對于 2019 年份的輸入,模型是最容易預測錯誤的,這也能理解因為訓練數(shù)據(jù)中 2019 年份的數(shù)據(jù)不多。
這個測試任務就是將實體做一個不正常的重新分類,比如:
volleyball: animal
onions: sport
broccoli: sport
hockey: animal
kale: sport
beet: sport
golf: animal
horse: plant/vegetable
corn: sport
football: animal
luge: animal
bowling: animal
beans: sport
archery: animal
sheep: plant/vegetable
zucchini: sport
goldfish: plant/vegetable
duck: plant/vegetable
leopard: plant/vegetable
lacrosse: animal
badminton: animal
lion: plant/vegetable
celery: sport
porcupine: plant/vegetable
wolf: plant/vegetable
lettuce: sport
camel: plant/vegetable
billiards: animal
zebra: plant/vegetable
radish: sport
輸入示例中包含了 [animal(動物), plant/vegetable(植物/蔬菜), sport(運動)]
三種類型標簽。現(xiàn)在將它們原來的標簽映射打亂,將動物映射為植物(duck: plant/vegetable
),將運動映射為動物(golf: animal
),將植物映射為運動(beans: sport
)。
接著測試 GPT-3 能否根據(jù)僅有的示例學會預測新的映射,下面是測試結(jié)果:
llama: plant/vegetable ?
cat: plant/vegetable ?
elephant: plant/vegetable ?
monkey: plant/vegetable ?
panda: plant/vegetable ?
cucumber: sport ?
peas: sport ?
tomato: sport ?
spinach: sport ?
carrots: sport ?
rugby: animal ?
cycling: animal ?
baseball: animal ?
tennis: animal ?
judo: animal ?
可以看到 GPT-3 能正確輸出映射關(guān)系。而即使將標簽改成無意義的符號比如 [^*, #@#, !!~]
,模型同樣可以輸出正確的預測。
經(jīng)過上面對上下文學習的介紹,相信讀者也能體會到其神奇之處。
為什么 LLM 能夠具備該能力?上下文學習的原理究竟是怎樣的呢?
接下來解讀一篇最近微軟研究院發(fā)布的文章[7]
,對于上下文學習能力來源的探究。
文章中提出,關(guān)鍵在于 LLM 中的注意力層(attention layers),在推理過程實現(xiàn)了一個隱式的參數(shù)優(yōu)化過程,這和 fine-tuning 的時候通過梯度下降法顯式優(yōu)化參數(shù)的過程是類似的。
文章[7]
中提出,一個線性的注意力層其實和基于梯度下降法優(yōu)化的全連接層是互為對偶的形式,具體怎么理解呢?
首先文章中定義,全連接層的初始參數(shù)矩陣為 W0
,參數(shù)的梯度矩陣為△W
,維度為 dout × din
。還有當前輸入向量 x
,維度為 din
。則經(jīng)過一次梯度下降法優(yōu)化的全連接層可以表示為:
其中 △W
由上一次的輸入 x'
和上一次全連接層的輸出梯度 e
計算得到:
怎么理解這個梯度的計算公式呢,我們畫個圖:
接下來看基于梯度下降法優(yōu)化的全連接和線性注意力層是怎么聯(lián)系起來的,
我們關(guān)注紅框部分,參數(shù)梯度矩陣 △W
是上一次輸入和上一次輸出梯度的外積求和,這部分可以等價變換為,首先讓上一次輸入xi'T
和當前輸入 x
做內(nèi)積,接著再和 ei
做內(nèi)積最后再求和。
接著如果我們將
ei
看做是一個 value
向量,E
是 value
矩陣xi'T
看做是一個 key
向量,X'
是 key
矩陣x
看做是一個 query
向量其實就等價于是一個線性的注意力層。
上一次輸入xi'T
和 x
先做內(nèi)積,就是相當于 key
矩陣和當前 query
向量做乘法,得到每個 value
向量的權(quán)值,然后每個 ei
和權(quán)值相乘再相加,就是所有 value
向量加權(quán)求和。
上下文學習怎么實現(xiàn)隱式 finetuning
文章中定義,將上下文學習輸入的最后一個詞表示定義為 query token ,維度是 d
則輸入到注意力層之后的 query 向量計算公式如下:
則對于最后一個 token 來說,經(jīng)過一個注意力頭操作的輸出公式如下:
其中 Wv
,WK
和 WQ
都是變換矩陣,維度是 d' × d
,X'
是輸入中示例部分的 token 向量表示,而 X
則表示輸入中示例部分之后又在最后一個詞之前的所有的 token 的向量表示。[X';X]
表示矩陣拼接。
然后論文中簡化了下公式,將注意力計算中的 softmax
操作去掉了,就得到了上面新的公式。
我們關(guān)注上公式的第二到第三行的變換,上圖解釋變換過程:
接著文章中將,輸入中示例部分之后又在最后一個詞之前的所有的 token 的 value
和key
相乘部分定義為 Wzsl
(zsl 表示 Zero-shot Learning,0樣本學習)當做是初始的權(quán)值:
則 Wzsl * q
就相當于是一個0樣本學習的 attention 結(jié)果,因為沒有加上前面示例部分的 attention 結(jié)果。接著就是根據(jù)前面全連接層和 attention 互轉(zhuǎn)的公式可得:
我們看右邊紅框部分的變換,我們將示例部分的 token attention 操作中的
Wv*X'
看做是對應前面全連接上一次計算的輸出梯度Wk*X'
看作是對應前面全連接上一次計算的輸入q
看作是當前的輸入然后就可以把推理得示例部分的 token attention 操作部分看做是對應 Wzsl
初始權(quán)值的更新梯度 △Wicl
(icl 表示 In-Context Learning)。
這就是為什么說 LLM 中的注意力層在推理過程中實現(xiàn)了隱式的參數(shù)優(yōu)化過程。所以這也是上下文學習能 work 的原因。
但是有個疑問就是 attention 機制不管模型規(guī)模大小都是一樣的操作,為什么模型規(guī)模得增加到一定程度上下文學習才能顯現(xiàn)呢?
我感覺還是回到模型規(guī)模和訓練數(shù)據(jù)上,首先 LLM 中的key, query, value
變換矩陣的維度 d ' x d
足夠大,其次預訓練的數(shù)據(jù)量也大,所以初始權(quán)值 Wzsl
足夠好只需要少量的示例梯度 △Wicl
更新參數(shù)之后就能 work 了,其實感覺就和 Few-Shot Learning 沒什么區(qū)別。