Ranking Distillation: Learning Compact Ranking Models With High Performance for Recommender System (KDD 2018)を読んだ

3行で

ランク学習にたいする知識蒸留タスクに対する手法を提案する論文です。 利用方法として、モデルサイズを小さくして応答時間を良くしたり、オンラインでの学習を実現しやすくできることが考えられます。

背景

Deepを使ってレコメンドタスクを解くことを考えます。どんなにオフライン指標が高いモデルができたとしても、インフラが貧弱で、サービスに載らなければ売上には繋がりません

... というわけで、こういったことの解決策の一つとして、「モデルを小さくする」ということが挙げられます。すこし悪いくらいの予測精度で、もっと小さなモデルを構築できれば、システムに載る可能性がでてきます。

いろんな方法が提案されていて、(中には3Dプリンタでモデルを書き出し、レーザー光を当てて出力を観測し予測、文字通り光の速さで予測できるよ、という論文もあります) 今回は蒸留によってモデルを小さくする話となります。

Knowledge Distillation

一般的なclassificationタスクでの蒸留手法はさまざま提案されていますが、ランク学習という枠組みでは解かれて来ませんでした。Classificationに対するKnowledge Distillationは大まかに言ってある入力にたいするラベルの分布、を真似するように小さいモデルを学習させますが、例えばレコメンドにおけるラベルの分類というのは下手をすればアイテムの個数分となります。

提案手法

ざっくりいうと次の手順を踏みます

  1. データセットD0を用いてTeacherである大きなモデルを学習させる
  2. データセットD1 (ラベルなし) をTeacherに予測させ、上位のDocumentを取得
  3. データセットD2 (ラベルあり) による誤差と、2で得られたDocumentを真似るような誤差とを合わせてChildを学習

f:id:Graphium:20190216170613p:plain

このうち3.でのロス関数は次の定義です。

 L (\theta _ s) = (1-\alpha)  L^{R} (\bf{y}, \hat{\bf{y}}) + \alpha L^{D}  (\pi _ {1 \cdots K}, \hat{\bf{y}})

右辺の左側は通常のpoint-wiseなランク学習のロス関数

 L^{R} (\bf{y},\hat{\bf{y}}) = - ( \sum _ {d \in \bf{y} _ {d+}} \log (P (rel = 1 | \hat{y} _ d)) + \sum_{d \in \bf{y} _ {d-}} \log (1 - P (rel = 1 | \hat{y} _ d)))

で、右が本題の手順2. のドキュメントを利用するときのロス関数に当たります(後述)

Teacher modelの蒸留

前述のロス関数右辺第二項は以下のように定義されます

 L^{D} (\pi _ {1 \cdots K}, \hat{\bf{y}}) = - \sum ^{K} _ {r=1} w _ {r} \cdot \log( P(rel = 1| \hat{y} _ {\pi _ r}))

 \pi _ {r} というのは、「Teacherがtop-r位に予測したDocument」 ということを示します.  K は、Teacherの予測の上位何位までをStudentの学習に用いるか、というハイパラになります。

そして、一つ一つのDocumentの誤差の和を取る際につける重み  w _ r

 w _ {r} =  \frac{ ( w ^ {a} _ {r} \cdot w ^ {b} _ {r} )  } {  \sum ^ {K} _  {i=1} w ^ {a} _ {i} \cdot w ^ {b} _ {r}}  で定義されます。

これは2つの要素からなっており、 Teacher が上位だと予測したdocumentほど、大きい値になる(= この重みが大きいDocumentを上位に予測するとより褒めてくれる) 重み  w _ a  w _ r ^ {a} \propto  e ^ {-r / \lambda }   ,  \lambda \in \mathbb{R} ^ {+}

と、 「このDocumentは本来Studentくんにr位くらいと予測してほしいが、ちょっと難しそうなので厳しく教えるね」 という重み  w _ b

 w _ r ^ {b} = tanh (max ( \mu \cdot ( \hat{r _ {\pi _ r }} - r), 0 ))

の2つを混ぜ合わせて計算します。

  \hat{r _ {\pi _ r }} というのは、「 \pi_r をStudentくんがドキュメント全体の何位に予測するだろうか」 という数値です。ランダムにサンプルしたデータの予測結果と比較して、全ドキュメントの何位あたりにに \pi_r が予測されるか、という数値を求め、 その順位が大きければ大きいほど w _ b は大きくなることになります

f:id:Graphium:20190216183318p:plain 

実験

データセット

データセットは次の2つ。全actionのうち、train:valid:test = 7:1:2としている。 f:id:Graphium:20190216184500p:plain

評価指標

Precision@n, nDCG@n, MAP ( n \in {3, 5, 10})

比較対象

  • Fossil : Similarity + 多次元マルコフチェイン
  • Caser : CNNによるDocumentのEmbeddingと、UserのEmbeddingのFusion
  • POP : popularityベースのレコメンド
  • ItemCF : 強調フィルタリング
  • BPR : Baysian personalized ranking

上2つの手法をTeacherとして、それぞれごく普通のLinear Layerを重ね合わせただけのStudentを学習させます。

結果

提案手法と比較対象の予測性能差

f:id:Graphium:20190216190210p:plain -T とついているのは元のTeacherモデルで、-RDとついているのは提案手法。 -SはTeacherを学習させたときと同じデータで、蒸留をせずStudentモデルを学習させた場合を表します。

ここで、-S のStudentモデルのモデルサイズは、「少しずつTeacherとcomparableになるまでモデルサイズを大きくしていった」と書かれています。

-RD と-Sを比較すると、-RDが大幅に性能がよい = 提案された蒸留方法がよく機能している ということが言えます

蒸留したモデルの予測計算時間

f:id:Graphium:20190216185633p:plain テストデータに対する全アイテムの予測時間です。先程の予測性能表と見比べると、「同じ程度の予測性能があるのに、予測にかかる時間は半分程度になったよ」ということが言えます

Distillation lossの有効性

f:id:Graphium:20190216185642p:plain 横軸iteration, 縦軸MAP. 最初はDistillation Lossを用いいずにStudentを学習させ、iteration = 30の破線があるタイミングでDistillation Lossを導入してみたところ(紫の折れ線)、 学習曲線がよくなったことから、提案したLoss関数が有効だと主張しています。

Studentのモデルサイズと予測性能の関係

f:id:Graphium:20190216185633p:plain 横軸にStudentのモデルサイズ (棒が右軸で示されるパラメータ数), 折れ線がStudentの予測性能(MAP), 破線がTeacherの予測性能となります。 「ある程度モデルが小さくなるまでは、モデルパラメータ数に対して線形に性能が変化する傾向」、つまり多少モデルサイズを減らしたところで一気にStudentの予測性能が落ちるわけではない、ということが分かります。

二種類の重みを用いることの有用性

f:id:Graphium:20190216185651p:plain 横軸に2種類の重みのどちらをより使うかのパラメータ、縦軸にMAP. いい感じに2つの重みをmixさせると予測性能いい感じだよ、というグラフですね.