第7回 リカレントニューラルネットワークの実装(3)|Tech Book Zone Manatee

マナティ

詳解 ディープラーニング

第7回 リカレントニューラルネットワークの実装(3)

前回第6回はTensorFlowによるリカレントニューラルネットワーク(RNN)を実装しましたが、今回はKerasでの実装について解説します。

電子書籍『詳解 ディープラーニング』をマナティで発売中!
(上の書籍画像をクリックすると購入サイトに移動できます)

Kerasによるリカレントニューラルネットワークの実装

 TensorFlowでは、モデルの設計部分の一部を自分で数式に沿って実装しなければなりませんでしたが、Kerasではその部分もメソッドが用意されているため、より簡単にモデルを記述できます。

 TensorFlow では tf.contrib.rnn.BasicRNNCell() でしたが、Keras では、

from keras.layers.recurrent import SimpleRNN

とすることでリカレントニューラルネットワークに対応させることができます。層の追加はこれまでと同様で、

model = Sequential()
model.add(SimpleRNN(n_hidden,
                    init=weight_variable,
                    input_shape=(maxlen, n_out)))
model.add(Dense(n_out, init=weight_variable))
model.add(Activation('linear'))

とするだけです。TensorFlow では state をさかのぼる時間分、隠れ層の出力を求める必要がありましたが、Keras ではそこも含めてライブラリ側で計算してくれます。最適化手法の設定に関してもこれまでと同じです。

optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999)
model.compile(loss='mean_squared_error',
              optimizer=optimizer)

誤差が mean_squared_error となっている点に注意してください。

 また、TensorFlow 同様、実際の学習の部分もこれまでとまったく同じコードで実現できます。

epochs = 500
batch_size = 10
 
model.fit(X_train, Y_train,
          batch_size=batch_size,
          epochs=epochs,
          validation_data=(X_validation, Y_validation),
          callbacks=[early_stopping])

Keras ではモデルの出力は model.predict() で得られるので、sin 波を生成するコードは下記になります。

truncate = maxlen
Z = X[:1] # 元データの最初の一部だけ切り出し
 
original = [f[i] for i in range(maxlen)]
predicted = [None for i in range(maxlen)]
 
for i in range(length_of_sequences - maxlen + 1):
    z_ = Z[-1:]
    y_ = model.predict(z_)
    sequence_ = np.concatenate(
        (z_.reshape(maxlen, n_in)[1:], y_),
        axis=0).reshape(1, maxlen, n_in)
    Z = np.append(Z, sequence_, axis=0)
    predicted.append(y_.reshape(-1))

 TensorFlow と比べ、Keras はかなりシンプルに実装がまとまりますが、コードの裏側でどのような計算が行われているのかについてはきちんと理解しておくようにしましょう。

著者プロフィール

巣籠悠輔(著者)
Gunosy、READYFOR創業メンバー、電通・Google NY支社に勤務後、株式会社情報医療の創業に参加。医療分野での人工知能活用を目指す。著書に『Deep Learning Javaプログラミング 深層学習の理論と実装』(インプレス刊、Packet Publishing:Java Deep Learning Essentials)がある。