18
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Mahoutでニューラルネットする

Posted at

ニューラルネットとDeep Learning

機械学習といわれるとDeep Learning(深層学習)という言葉を思い浮かべる人も多いかと思います。

Deep Learning自身は特に新しい概念ではなく、古くからあるニューラルネットが多層になったものと言っていいと思います。学習の難しさや計算量の多さから一旦廃れたニューラルネットでしたが、学習方法の工夫や計算能力の向上で再び脚光を浴びてきたのがDeep Learningです。(本質は変わってないので今まで通り「ニューラルネット」と呼べばいいような気がしますが、別名を付けて新しさを出そうとするIT業界ではよくある手でしょうか)

音声認識の世界でも、今までの主流の方式だったHMM(Hiddlen Markov Model)に替わって、Deep Learningを使う方法が広まりつつあります。そのおかげで精度や応答速度が向上しているそうです。

Mahoutでニューラルネット

そんなホットな機械学習であるDeep Learningを、ぜひMahoutでもやってみたいというのが今回の記事です。Mahoutには機械学習のためのアルゴリズムがいろいろ入っていますが現在リリースされているMahoutバージョン0.9にはMultilayer Perceptron(多層パーセプトロン)、いわゆるニューラルネットワークも入っているとドキュメントに書いてあります。これを使えばDeep Learningできるのではないでしょうか?しかし、そんなに甘くはありません。

使い方がわからない

Multilayer Perceptronは入っていると書いてあるものの使い方に関する記述はどこにも一切ありません。探してみるとMahoutチケット管理のページにコマンドラインからの使い方Readmeを作成することというタスクがありドラフトのドキュメントが添付されていました。

ドキュメントくらい入れておいてほしいと思いつつこれを参考に使って見ることにします。

そもそもコマンドが無い

Readme記載の実行例を実行すると、クラスが無いというエラーが出てしまいました。ここまできてはじめてわかりましたがバージョン0.9にはアルゴリズムは入っていますがそれをコマンドラインから使うためのクラスが無いということです。

$ bin/mahout org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron .......
WARN driver.MahoutDriver: Unable to add class: org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron
java.lang.ClassNotFoundException: org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron
	at java.net.URLClassLoader$1.run(URLClassLoader.java:217)
	at java.security.AccessController.doPrivileged(Native Method)
	at java.net.URLClassLoader.findClass(URLClassLoader.java:205)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:321)
	at sun.misc.Launcher$AppClassLoader.loadClass(Launcher.java:294)
	at java.lang.ClassLoader.loadClass(ClassLoader.java:266)
	at java.lang.Class.forName0(Native Method)
	at java.lang.Class.forName(Class.java:186)
	at org.apache.mahout.driver.MahoutDriver.addClass(MahoutDriver.java:237)
	at org.apache.mahout.driver.MahoutDriver.main(MahoutDriver.java:128)
14/12/29 17:30:48 WARN driver.MahoutDriver: No org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron.props
found on classpath, will use command-line arguments only

最新版Mahoutでニューラルネット

残念ながらバージョン0.9でニューラルネットするのは諦めて、開発中の最新のソースを持ってきてビルドしあらためてニューラルネットすることにしました。

ビルドは簡単で、まずMahoutのGithubから最新版Mahoutのソースを持ってきます。そしてルートのフォルダでMavenを実行します。Mavenはバージョン3系が必要です。たまたま使っていた環境Ubuntu10ではapt-getでmaven2系しかインストールできなかったのでMavenのサイトよりダウンロードしパスを通しました。

Mahoutのルートのフォルダで以下のコマンドを実行するとビルドが完了します。

mvn -DskipTests clean install

ニューラルネットワークを学習する

さっそく、ニューラルネットワークを学習させてみましょう。ソースに付属しているiris.csvをサンプルとして使います。このデータは3種類のあやめの4個の計測値をリストしたもので、分類など機械学習のテストデータとして有名なものということです。(参考文献を参照

学習にはorg.apache.mahout.classifier.mlp.TrainMultilayerPerceptronクラスを使用します。このクラスにオプションを指定して以下のように実行します。

bin/mahout org.apache.mahout.classifier.mlp.TrainMultilayerPerceptron -i mrlegacy/src/test/resources/iris.csv \
    -sh -labels setosa versicolor virginica -mo iris_model.model -ls 4 8 3 -l 0.2 -m 0.35 -r 0.0001

オプションはそれぞれ以下の意味を持ちます。

  • -i 入力ファイルを指定する。入力はCSVファイルのみ。
  • -sh このオプションをつけると入力ファイル1行目をヘッダ行とみなす。
  • -labels 出力として分類されるラベルを列挙。
  • -mo ニューラルネットワークのモデルデータを保存するファイル。
  • -ls 各レイヤのニューロンの数を指定。左より、入力層、中間層、出力層になる。
  • -l Learning Rate (省略可)
  • -m Momemtum weight (省略可)
  • -r Regularization weight (省略可)

入力ファイルの形式はCSVファイルで、各行が1つのデータの組となります。最後の列が出力として分類されるラベルとなります。その前の列はすべて入力ベクトルとして表されます。iris.csvの場合は、先頭の4列が花の計測値。最後の列がその花の分類を示すラベルになります。

iris.csv
Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
5.1,3.5,1.4,0.2,setosa
4.9,3.0,1.4,0.2,setosa
4.7,3.2,1.3,0.2,setosa
4.6,3.1,1.5,0.2,setosa
5.0,3.6,1.4,0.2,setosa
5.4,3.9,1.7,0.4,setosa
:

入力層のニューロン数は入力ベクトルの大きさ、出力層のニューロン数は出力ラベルの数と合わせる必要があります。中間層のニューロン数は任意のようです。

実行すると-moオプションで指定したファイルにニューラルネットワークのモデルが生成されました。

$ ls -l iris_model.model 
-rw-r--r-- 1 vmplanet vmplanet 710 2014-12-29 17:48 iris_model.model

参考文献

ニューラルネットワークで分類する

ニューラルネットワークで分類するにはorg.apache.mahout.classifier.mlp.RunMultilayerPerceptronクラスを使用します。このクラスにオプションを指定して以下のように実行します。

bin/mahout org.apache.mahout.classifier.mlp.RunMultilayerPerceptron -i mrlegacy/src/test/resources/iris.csv \
  --skipHeader --columnRange 0 3 -mo iris_model.model -o iris_result.txt

オプションはそれぞれ以下の意味を持ちます。

  • -i 入力ファイルを指定する。入力はCSVファイルのみ。
  • --skipHeader このオプションをつけると入力ファイル1行目をヘッダ行とみなす。(なぜか-shだとエラー)
  • --columnRange 入力ファイル中の入力値の開始列と終了列のインデックスを指定する。
  • -mo ニューラルネットワークのモデルデータのファイル。
  • -o 出力を格納するファイル。存在していると分類が実行されないので事前に消すこと。

実行すると-oオプションで指定したファイルに分類結果が出力されました。分類結果はラベルのインデックスで示されます。"0"ならば学習時に-labelオプションの最初に指定した"setosa"に分類されたことを示します。結果を見てみます。

$ more iris_result.txt 
222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222222

... 全部 "virginica"に分類されてしまってます。全然分類できていません。何がいけなかったのでしょうか...

学習データを変えてみる

ニューラルネットの学習というのは入力に対して期待した出力を出せるように重みを調整しながら進めていくのですが、この過程が学習順にものすごく影響を受けるとのことです。学習につかったiris.csvは分類されるべきデータが"setosa"、"versicolor"、"virginica"がそれぞれ50個ずつ順に並んでいます。このため最後に学習した "virginica"の影響を大きく受けていた可能性があります。

そこで、学習につかうデータの行を入れ替え"setosa"、"versicolor"、"virginica"が順にあらわれるようにしました。また学習データを増やすためこの150個のデータを10回繰り返してファイルの末尾につけ、合計1500個の学習データにしました。

iris10.csv
Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
5.1,3.5,1.4,0.2,setosa
7,3.2,4.7,1.4,versicolor
6.3,3.3,6,2.5,virginica
4.9,3,1.4,0.2,setosa
6.4,3.2,4.5,1.5,versicolor
5.8,2.7,5.1,1.9,virginica
:

このデータを使い学習し、もう一度iris.csvで分類のテストをしてみます。

$ cat iris_result.txt 
000000000000000000000000000000000000000000000000001111111111111111111121211111111112211111111111111122222222222222222222222222222222222222222222222222

こんどはいい感じで分類できているようです。150個中146個正解。正答率97.3%でした。

入力データの工夫など扱いが少し難しいのが難点ですがMahoutでしっかりニューラルネットできました。隠れ層を変えたり違うデータでも試してみたいですが今日はここまで。

18
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?