2023.02.27
GPUに対応した最新の機械学習ライブラリを使いながらディープラーニングを基本から学べる! 『JAX/Flaxで学ぶ ディープラーニングの仕組み』
『TensorFlowとKerasで動かしながら学ぶディープラーニングの仕組み』という書籍をご存知でしょうか? タイトルの通り、TensorFlowとKerasを実際に使いながらディープラーニングを学べる内容で、Manateeでもご好評いただいています。
今回は本書の内容をGoogleの研究者にも注目されている最新のライブラリであるJAXとFlaxで再構成した書籍をご紹介します。
まだ日本語の情報が少ないJAX/Flaxを動かしながら学べる
JAX/Flaxを学ぶ上でまず大変なのは、日本語の情報が少ないことではないでしょうか。Qiitaや個人ブログではいくつか見かけますが、書籍としてはまだ発行されていないようです(2023年2月時点)。
このような状況だとJAX/Flaxのインストールはどうやるのか?どんな関数があるのか?そもそも何が違うの?TensorFlowとKerasでいいんじゃない? そういった疑問を解消するのも一苦労です。
調べてみようにも英語の公式リファレンスを見るのはちょっと大変…そこで、『TensorFlowとKerasで動かしながら学ぶディープラーニングの仕組み』が『JAX/Flaxで学ぶディープラーニングの仕組み』に生まれ変わりました。
本書ではTensorFlow/KerasとJAX/Flaxの違いを詳しく解説しています。第1章をていねいに読み解いていけばJAX/Flaxの基本的な使い方がわかります。また、TensorFlow/KerasがあるのにJAX/Flaxが必要とされている理由も納得できるでしょう。
たとえば、TensorFlow/Kerasだと基本的な機械学習モデルを構築して与えられたデータで学習する上では非常に簡単なコードで実装できます。
しかし、学習中のモデルの中身を分析する、学習済みのパラメータを他のモデルに移すといった「研究・開発レベルの機械学習」の難易度は高くなっています。
JAX/Flaxであれば、モデルとパラメータが分離しているため、パラメータの調整や管理がやりやすくなっています。
Jax/Flaxはモデルの学習に必要なコードを自分で用意するため、入門レベルの機械学習を実施する場合はTensorFlow/Kearasにくらべて手間がかかりますが、その一方で、より高度な機械学習モデルを実装するのは容易になっていると言えます。
仕組みを理解する上では、モデルとパラメータが分かれていたり、学習のコードを自分で用意する方が、より早く理解が進むのではないでしょうか。
難易度が高くなっているとはいっても、本書ではコーディングの手順をライブラリのインストール方法から順番に細かく解説しているので一つ一つ丁寧に読んで実行していくことで自然とJAX/Flaxを用いたディープラーニングを身に着けることができます。
本書だけでJAX/Flaxを使った手書き文字認識、多層ニューラルネットワークの実装、畳み込みフィルタを使った画像分類や転移学習、DCGANなどをまとめて学ぶことができます。
実行環境はColaboratoryを採用しているので、ディープラーニングは知ってるけどJAX/Flaxはまったく知らないという方でも是非気軽に始めてみてください。
本書で解説に使ったipynbファイルはすべてダウンロードすることができます。こちらも学習に役立ててください。
ディープラーニングの仕組みを数式と豊富な図で詳しく解説
本書に掲載されているコードを一つ一つ順番に実行していくと手書き文字認識やDCGANによる画像生成などができるようになります。
プログラミングに慣れている人ならJAX/Flaxを使って簡単なモデルを構築して学習処理を実行するのはすぐにできるようになるかもしれません。
ですが、研究や開発をするうえでは定型的なモデル構築だけではなく、モデルの分析ができるようになる必要があります。
そのためには原理の理解が欠かせません。そこで、本書では数式を用いてディープラーニングへの理解を深めていきます。
数学は全然覚えてないから不安だ、という方もご安心ください。本書では図版も豊富に使用しています。まずはこちらを見ながら、少しずつ数式にも目を通してみるのがオススメです。
本書の330ページ以降では本書で使用されている数学の公式と、利用されている関数が簡単にまとめられています。
特に行列の計算はディープラーニングでは必ず覚える必要があるので、是非とも本書をきっかけにして身に着けていただければ今後にも役に立つかと思います。
関数のまとめは代表的な関数のリファレンスとして、本書の復習としてご活用ください。
まとめ
本書はManateeでも大人気の書籍、『TensorFlowとKerasで動かしながら学ぶディープラーニングの仕組み』をJAX/Flaxで再構成した内容になっています。
本書を通して、畳み込みニューラルネットワーク(CNN)の仕組みを理解しながらJAX/Flaxで実装することができるようになります。
ただ使うライブラリや関数を覚えてモデルを作るだけではなく、原理を理解できる内容になっているので、Pythonの知識を利用して様々な応用ができるようになります。
JAX/Flaxは比較的新しい機械学習ライブラリで、日本語でこれだけ詳しく丁寧に解説した書籍は本書以外にはまだ発行されていません(2023年2月時点)。
GoogleのアカウントとPCさえあれば実行できる内容になっているので、JAX/Flaxなんて初めて聞いたけどディープラーニングには興味がある方はもちろん、いち早くディープラーニングのトレンドを押さえて活用したいエンジニアの方には特にオススメの書籍です。