Chainerで競馬予想してみた
今話題のディープラーニングで競馬予想、してみたくない? 素人なりに情報を整理しました。 数学的に適当なことを申しておりますので、ツッコミいただけると幸いです。
環境構築
Chainerの入門にはこちらのページの内容が参考になりました。
GitHub - takecian/HorseRacePrediction: Chainer で競馬予想
GitHubのプロジェクトを参照すると、雰囲気がつかめます。
上記プロジェクトに手を入れて、自分好みにしていきます。
回帰問題と分類問題
回帰問題とは、ざっくりいうと、実数の出力が必要になる問題です。 例) 馬の走破タイムを求めるなど
一方で、分類問題とは、分類後のラベルの出力が必要になる問題です。 例) 複勝圏内に入るか否か、馬の順位を求めるなど
Chainerにおける回帰問題と分類問題
回帰問題にする
predicator.pyの、L.Classifier
の初期化を以下のようにします。
model = L.Classifier(MLP(args.unit, 1), lossfun=F.mean_squared_error) model.compute_accuracy = False
上記のマジックナンバー1
は、出力の数で、一頭の馬の走破タイムであれば、出力は1つになります。
教師データとして、
[1, 2, 3]
のときの出力が[1]
[3, 4, 5]
のときの出力が[2]
[6, 7, 8]
のときの出力が[3]
があるとします。
このとき、入力の配列は以下のようにし、
[[1, 2, 3], [4, 5, 6] [7, 8, 9]]
正解配列は、
[[1], [2], [3]]
となります。
これらを、 predicator.pyの以下のコードのように、設定していきます。(ここは書き換え不要です。)
... train = tuple_dataset.TupleDataset(loader.train_data, loader.train_data_answer) test = tuple_dataset.TupleDataset(loader.test_data, loader.test_data_answer) ...
分類問題にする
predicator.pyの書き換えは不要です。注意したいのは以下です。
教師データとして、
[1, 2, 3]
のときの出力が0
[3, 4, 5]
のときの出力が1
[6, 7, 8]
のときの出力が2
があるとします。
このとき、入力の配列は以下のようにし、
[[1, 2, 3], [4, 5, 6] [7, 8, 9]]
正解配列は、
[0, 1, 2]
となります。正解ラベルの配列には制約があり、
- 1次元配列でなければならない
- 0始まりでなければならない
- int32でなければならない
です。
そうでないとよくわからないエラーを吐くので注意です。
入力となるパラメータ
こちらのサイトがとても参考になります。 こちらのサイトの通り、
- 複勝圏内に入るかどうかの分類問題
- 特定の場所、芝・ダート、距離で学習
という感じでやります。 入力パラメータは上記サイトのパラメータを使用すると、回収率が100%超えるくらいには再現しました。
パラメータの標準化
パラメータは、距離なら1000 〜 2000、一位との着差タイムだと、0 〜 10など数値自体の大きさが違います。 すると、学習効率が悪いらしいので、標準化します。
下記サンプルコードのように、numpyのmean()
やstd()
をつかって平均や標準偏差などを求め、
パラメータを平均0で、-1から1までに分散するよう標準化します。
def normalize(x, mean, std): x_copy = np.copy(x) x_norm = (x_copy - mean) / std return x_norm means = train_data.mean() stds = train_data.std() train_data = normalize(train_data, means, stds)
hyperoptでハイパーパラメータを設定する
chainerで作ったDeep Learningモデルのハイパーパラメータチューニングを自動化してみる - verilog書く人
こちらの記事を参考に、hyperoptの使用コードを実装します。
そうしてみると、意外とreluよりsigmoidのほうがいいのかな?とか
batchsizeが100より500のほうがaccuracyが安定して上昇するぞとか、わかります。
おまけ
JRA-VANのデータはCSVで提供されており、ヘルプサイトにCSVのヘッダが置いてあったりするので、 SQLiteを使うとインポートが楽です。一行目のカラム名がついたテーブルを作成してくれます。
sqlite3 > .mode csv > .import ファイル名 テーブル名
特定の馬の前回のレース結果を取ったりするのにSQLが便利です。