2017.06.30
第4回 Backpropagation Through Time(BPTT)
リカレントニューラルネットワークの誤差を求める際には時間軸で展開するとイメージしやすいです。誤差が時間をさかのぼって逆伝播するBackpropagation Through Time(BPTT)について解説します。
電子書籍『詳解 ディープラーニング』をマナティで発売中!
(上の書籍画像をクリックすると購入サイトに移動できます)
リカレントニューラルネットワークの誤差を求める際、気を付けなければならない点が1つあります。一般的なニューラルネットワークでは、例えば誤差関数を2 乗誤差関数として考える場合、式
と同様、誤差 eh(t), eo(t) は、
で与えられることになります。これらの式自体は正しいのですが、リカレントニューラルネットワークでは、ネットワークの順伝播で時刻 t – 1 における隠れ層の出力 h(t – 1) を考えたため、逆伝播の際も t – 1 における誤差を考える必要があります。
リカレントニューラルネットワークのモデルを時間軸で展開するとイメージしやすいかもしれません。例えば下の図は時刻 t – 2 の入力 x(t – 2) まで展開したものになりますが、誤差 eh(t) は eh(t – 1) に逆伝播し、さらに eh(t – 1) は eh(t – 2) に逆伝播します。このように、順伝播の際は h(t) が h(t – 1) の再帰関係式で表されたのと同様、逆伝播の際は eh(t – 1) を eh(t) の式で表す必要があります。このとき、誤差は時間をさかのぼって逆伝播していることになるので、これを Backpropagation Through Time と呼び、BPTT と略記します。
BPTT という手法の名前は付いていますが、考えるべきは eh(t – 1) を eh(t) の式で表すことです。 t – 1 における誤差は、
なので、再帰関係式を求めると、
が得られます。よって、再帰的に eh(t – z – 1) と eh(t – z) は
の関係で表すことができます。これによりすべての勾配が計算できることが分かり、各パラメータの更新式は、
となります。この 𝜏 がどれくらいの過去までさかのぼって時間依存性を見るかを表すパラメータなので、理想的には 𝜏 → +∞ とすべきです。しかし、現実には勾配が消失(あるいは爆発)してしまうことを防ぐため、せいぜい 𝜏 = 10 ~ 100 くらいに設定するのが一般的です※。
※すなわち、より長期にまたがる時間依存性はここで考えるアプローチでは学習できないことになります。この問題を解決するための手法は次回の記事で扱います。