前回の記事で作成したUbuntu16.04LTS上のPython機械学習環境で日経ソフトウェア11月号のサンプルプログラムを動作させます。記事の通りkerasライブラリを使用することで簡単に
ディープラーニングを試すことができます。
機械学習データにはAnacondaと共にインストールされたscikit-learnライブラリに同梱されている、アヤメの品種分類データを使用します。動作環境はPython対話モード(インタプリタ)とします。
目次
1.アヤメデータ展開
(1)データロード
以下の通り、「python」コマンドで対話モードに入りsklearnライブラリをインポートします。このライブラリにはアヤメデータクラス用の専用メソッドが用意されています。load_iris()メソッドでアヤメデータクラスを取得し、DESCRメソッドでアヤメデータクラスの情報を参照しprint文で情報を出力しています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
$ python >>>from sklearn import datasets >>>iris = datasets.load_iris() >>>>print(iris.DESCR) Iris Plants Database ==================== Notes ----- Data Set Characteristics: :Number of Instances: 150 (50 in each of three classes) :Number of Attributes: 4 numeric, predictive attributes and the class :Attribute Information: - sepal length in cm - sepal width in cm - petal length in cm - petal width in cm - class: - Iris-Setosa - Iris-Versicolour - Iris-Virginica :Summary Statistics: ... |
上記DESCR情報により、アヤメデータは150×4の2次元配列になっています。150のうち最初の50個が「setosa」品種、次の50個が「Versicolour」品種、最後の50個が「Virginica」品種のデータです。4は品種を識別する以下の属性データです。この4種類の値の違いで品種を識別しています。
sepal length:がく片の長さ
sepal width :がく片の幅
petal length:花びらの長さ
petal width :花びらの幅
「がく片」とは外側の大きい花びら、「花びら」が内側の小さい花びらのことのようです。
(2)リストデータ参照
dataリスト、targetリストを参照してみます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
>>>iris.data array([[ 5.1, 3.5, 1.4, 0.2], [ 4.9, 3. , 1.4, 0.2], [ 4.7, 3.2, 1.3, 0.2], [ 4.6, 3.1, 1.5, 0.2], [ 5. , 3.6, 1.4, 0.2], [ 5.4, 3.9, 1.7, 0.4], [ 4.6, 3.4, 1.4, 0.3], [ 5. , 3.4, 1.5, 0.2], [ 4.4, 2.9, 1.4, 0.2], [ 4.9, 3.1, 1.5, 0.1], [ 5.4, 3.7, 1.5, 0.2], [ 4.8, 3.4, 1.6, 0.2], [ 4.8, 3. , 1.4, 0.1], ... >>>iris.target array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) |
dataリストが(1)で説明した150×4の実データです。targetは150の属性データに対応した品種を識別する番号が0~2まで50個づつ振られています。0がsetosa品種、1がVersicolour品種、2がVirginica品種を意味します。
2.機械学習
(1)準備
150個のアヤメデータをシャッフルし学習用120個テスト用30個に分割します。シャッフルと分割はsklearnライブラリに専用のメソッドが用意されていますのでこれを使用します。
1 2 |
>>> from sklearn.model_selection import train_test_split as split >>> x_train, x_test, y_train, y_test = split(iris.data,iris.target,train_size=0.8,test_size=0.2) |
kerasライブラリをインポートします。
1 2 |
>>> import keras >>> from keras.layers import Dense, Activation |
(2)モデル作成
ニュートラルネットワークで使用するモデルを作成します。
1 2 3 4 5 6 |
>>> model = keras.models.Sequential() >>> model.add(Dense(units=32, input_dim=4)) >>> model.add(Activation('relu')) >>> model.add(Dense(units=3)) >>> model.add(Activation('softmax')) >>> model.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=['accuracy']) |
1行目:レイヤーの線形スタックであるSequentialモデルを適用します
2行目:隠れ層が32個、入力層が4個(アヤメ属性数)のニューロンを指定します
3行目:隠れ層の活性化関数にReLU関数を適用します
4行目:出力層を3個(アヤメの品種)にします
5行目:出力層の活性化関数にsoftmax関数を適用します
6行目:compileでモデルを構築します。引数の意味は下記。
loss :損失関数(コスト関数)、今回はクロスエントロピー法を適用
optimizer:最適化アルゴリズム、今回は確率的勾配降下法を適用
metrics :評価関数のリスト、今回は正解率を適用
(3)実行
モデルができたので、学習データを与えて機械学習を開始します。epochs=100を指定し学習を100回行うようにします。機械学習では繰り返し学習をすることで入力層、隠れ層、出力層からなるニュートラルネットワークの重み(パラメータ)を最適化できます。
1 |
>>> model.fit(x_train,y_train,epochs=100) |
3.学習結果の評価
(1)テストデータでの判定
テストデータを入力にevaluate()で構築したニュートラルネットワークモデルの評価を行います。
1 2 3 4 |
>>> score = model.evaluate(x_test,y_test,batch_size = 1) 1/30 [>.............................] - ETA: 0s >>> print("正解率(accuary)=",score[1]) 正解率(accuary)= 0.866666666667 |
日経ソフトウェア記事の正解率は96.7%でしたが今回の正解率は86.7%とあまり良い数字ではありませんでした。ランダムにトレーニングデータとテストデータを選択しているのでその時選択されるデータにより正解率は異なるようです。
(2)分類
predict()で任意のデータに対する分類をしてみます。分類対象はNumPyの配列形式にする必要があるようです。
1 2 3 4 5 |
>>> import numpy as np >>> x = np.array([[5.1,3.5,1.4,0.2]]) >>> r = model.predict(x) >>> r array([[ 0.9518137 , 0.04466538, 0.00352097]], dtype=float32) |
0:setosa品種である確率が95.2%と分類されました。
簡単に最も確率の高い分類を判定するargmax()があります。
1 2 |
>>> r.argmax() 0 |
0:setosa品種と判定されました。
4.参考にさせて頂いたサイト&情報
pyenv + scikit-learn に付属しているデータセット
ありがとうございます。