?論文標(biāo)題 | A Graph Autoencoder Approach to Causal Structure Learning
論文來源 | NeurIPS (workshop) 2019
論文鏈接 | https://arxiv.org/abs/1911.07420
源碼鏈接 | https://github.com/huawei-noah/trustworthyAI
?
對于變量間因果圖結(jié)構(gòu)學(xué)習(xí)的問題,論文中基于圖自編碼器和梯度優(yōu)化的方法來學(xué)習(xí)觀測數(shù)據(jù)中的因果結(jié)構(gòu),主要可以解決非線性結(jié)構(gòu)等價(jià)問題并且將其應(yīng)用到向量值形式的變量因果結(jié)構(gòu)預(yù)測中。實(shí)驗(yàn)部分在人工生成數(shù)據(jù)中驗(yàn)證了提出的 GAE 模型優(yōu)于當(dāng)前其它基于梯度優(yōu)化的模型 NOTEARS 和 DAG-GNN 等,尤其是在規(guī)模較大的因果圖預(yù)測問題中。此外還測試了模型效率問題,在訓(xùn)練過程中隨著圖規(guī)模增大可以達(dá)到線性時(shí)間。
對于如何學(xué)習(xí)變量間因果圖結(jié)構(gòu)的問題,當(dāng)前主流方法主要可以劃分為三種類型:
「Constraint-based」:PC,F(xiàn)CI 等
PC 算法可以參考我另一篇文章 因果推理之因果圖挖掘 PC 算法
「Score-based」:GES[^1] 等
「Gradient-based」:NOTEARS[^2],DAG-GNN 等
DAG-GNN 可以參考我另一篇文章 DAG-GNN:基于圖神經(jīng)網(wǎng)絡(luò)的有向無環(huán)圖結(jié)構(gòu)表示學(xué)習(xí)
相比于 Constraint- /Score-based 系列的方法,Gradient-based 的方法準(zhǔn)確性和計(jì)算效率更高,目測解釋性差點(diǎn)但奈何圖深度學(xué)習(xí)??啊。因此主要介紹基于梯度優(yōu)化的因果圖學(xué)習(xí)發(fā)展背景。
令 DAG 表示因果圖,包含節(jié)點(diǎn) 其中 ,主要考慮的問題是加性噪聲模型 (ANM) 下的因果結(jié)構(gòu)學(xué)習(xí)方法,假設(shè)數(shù)據(jù)生成如下所示:
其中 表示 中存在有向邊指向變量 的節(jié)點(diǎn)集合, 為向量映射函數(shù), 表示加性噪聲并假設(shè)是獨(dú)立同分布的。集合表示為 和 。
NOTEARS 首先將 score-based 系列的組合優(yōu)化問題轉(zhuǎn)化為的線性結(jié)構(gòu)等價(jià)模型(SEM)。對于上述數(shù)據(jù)生成模型改寫為
假設(shè) 并且 表示系數(shù)向量 表示線性 SEM 的加權(quán)鄰接矩陣。為了保證因果圖 是有向無環(huán)的,需要對 進(jìn)行限制,
NOTEARS 將最小平方損失函數(shù)作為優(yōu)化目標(biāo)函數(shù),如下所示
其中 表示 的第 個觀測值。從這定義就可以看出 NOTEARS 只能處理單值變量間的因果,而且是線性結(jié)構(gòu)等價(jià)模型。這也是 GAE 所優(yōu)化改進(jìn)的場景。
DAG-GNN 為了將上述模型適用到非線性場景,提出了的生成式模型如下
其中 表示非線性函數(shù),DAG-GNN 用到的是 MLP + GNN 方法; 作為隱變量而且維度可以小于變量數(shù)量 。模型設(shè)計(jì)細(xì)節(jié)可以參考我另一篇文章 DAG-GNN:基于圖神經(jīng)網(wǎng)絡(luò)的有向無環(huán)圖結(jié)構(gòu)表示學(xué)習(xí)
以上簡單介紹了因果圖挖掘的背景知識,也是個初學(xué)者的綜述性介紹,下面進(jìn)入這篇文章的正文。
這篇論文主要是基于 NOTEARS 方法進(jìn)行改進(jìn),從而使其適用更多場景。主要模型架構(gòu)如下所示
改進(jìn)主要包括兩部分:非線性因果學(xué)習(xí)和向量形式變量可用而不僅是標(biāo)量數(shù)據(jù)。
對于 NOTEARS 優(yōu)化的目標(biāo)函數(shù)可以改寫為
其中 表示數(shù)據(jù)生成模型,對于 NOTEARS 就是線性 SEM 即
為了將 擴(kuò)展為非線性的,可以自定義一個非線性的關(guān)系映射,例如文章用到
其中 為非線性函數(shù)可選為 MLP,和 DAG-GNN 想法類似。為了增強(qiáng)非線性就再加一層 MLP
上面的公式一看,不就和 GAE 的計(jì)算形式差不多,重寫一遍
如果 和 分別表示 variable-wise 編碼器和解碼器,上面的計(jì)算形式和優(yōu)化目標(biāo)函數(shù)不就是基于重構(gòu)誤差訓(xùn)練的 GAE,GAE 不就可以處理 vector-valued 的變量了么 ??
論文中選擇兩個 variable-wised MLPs 和 其中 表示隱藏層維度。最終的優(yōu)化函數(shù)即為
和 DAG-GNN 主要的不同點(diǎn)在于:論文中用的 GAE 是以 作為輸入,而 DAG-GNN 是生成式模型以噪聲數(shù)據(jù) 作為輸入。
個人感覺就是 GAE 和 VGAE 的區(qū)別,GAE PK VGAE 怎么說好呢?作者只能在實(shí)驗(yàn)中證明了 GAE 比 VGAE 效果好而且快。??
對于上述優(yōu)化函數(shù),作者采用了增廣的拉格朗日乘子法進(jìn)行求解,其形式如下
其中 , 表示拉格朗日乘子, 表示懲罰因子,因此對應(yīng)的梯度更新規(guī)則如下
其中 而且 表示可調(diào)的超參數(shù)。
實(shí)驗(yàn)部分在人工數(shù)據(jù)集中對比 baselines NOTEARS 和 DAG-GNN。數(shù)據(jù)包括兩種 Scalar-based 和 Vector-Valued。
采用的指標(biāo)包括兩種結(jié)構(gòu)化漢明距離(SHD)和正陽率 (TPR)。實(shí)驗(yàn)結(jié)果如下所示
整體而言性能提升蠻大的,而且模型訓(xùn)練實(shí)驗(yàn)也比較短。
[^1]: David M Chickering. Optimal structure identi?cation with greedy search. Journal of Machine Learning Research, 3:507–554, March 2003.
[^2]: Xun Zheng, Bryon Aragam, Pradeep K Ravikumar, and Eric P Xing. DAGs with NO TEARS: Continuous optimization for structure learning. In Advances in Neural Information Processing Systems 31, 2018.
介紹因果圖挖掘是為了了解如何用深度學(xué)習(xí)的方法挖掘變量間的因果圖,而不要局限在 PC 或者 kNN 的思路上。
尤其是之前介紹的多變量時(shí)間序列關(guān)聯(lián)挖掘,時(shí)間序列變量間存在因果關(guān)系但之前使用的因果挖掘方法卻過于簡單粗暴。
圖作為一種廣義的數(shù)據(jù)結(jié)構(gòu),任何存在某種關(guān)聯(lián)的數(shù)據(jù)都可以使用當(dāng)前??的圖機(jī)器學(xué)習(xí)進(jìn)行建模。目前很多任務(wù)都是已知圖結(jié)構(gòu),對于未知圖結(jié)構(gòu)的實(shí)體就需要使用模型學(xué)習(xí)到其中關(guān)系或者因果,這往往是一個更有挑戰(zhàn)的任務(wù)。