とある地味なブログ

プログラミングとお絵かきに関する雑記。

Chainerで競馬予想してみた

今話題のディープラーニングで競馬予想、してみたくない? 素人なりに情報を整理しました。 数学的に適当なことを申しておりますので、ツッコミいただけると幸いです。

環境構築

qiita.com

Chainerの入門にはこちらのページの内容が参考になりました。

GitHub - takecian/HorseRacePrediction: Chainer で競馬予想

GitHubのプロジェクトを参照すると、雰囲気がつかめます。

上記プロジェクトに手を入れて、自分好みにしていきます。

回帰問題と分類問題

回帰問題とは、ざっくりいうと、実数の出力が必要になる問題です。 例) 馬の走破タイムを求めるなど

一方で、分類問題とは、分類後のラベルの出力が必要になる問題です。 例) 複勝圏内に入るか否か、馬の順位を求めるなど

Chainerにおける回帰問題と分類問題

回帰問題にする

predicator.pyの、L.Classifierの初期化を以下のようにします。

model = L.Classifier(MLP(args.unit, 1), lossfun=F.mean_squared_error)
model.compute_accuracy = False

上記マジックナンバー1は、出力の数で、一頭の馬の走破タイムであれば、出力は1つになります。

教師データとして、

  • [1, 2, 3]のときの出力が[1]
  • [3, 4, 5]のときの出力が[2]
  • [6, 7, 8]のときの出力が[3]

があるとします。

このとき、入力の配列は以下のようにし、

[[1, 2, 3], 
[4, 5, 6] 
[7, 8, 9]]

正解配列は、

[[1], [2], [3]]

となります。

これらを、 predicator.pyの以下のコードのように、設定していきます。(ここは書き換え不要です。)

...
train = tuple_dataset.TupleDataset(loader.train_data, loader.train_data_answer)
test = tuple_dataset.TupleDataset(loader.test_data, loader.test_data_answer)
...

分類問題にする

predicator.pyの書き換えは不要です。注意したいのは以下です。

教師データとして、

  • [1, 2, 3]のときの出力が0
  • [3, 4, 5]のときの出力が1
  • [6, 7, 8]のときの出力が2

があるとします。

このとき、入力の配列は以下のようにし、

[[1, 2, 3], 
[4, 5, 6] 
[7, 8, 9]]

正解配列は、

[0, 1, 2]

となります。正解ラベルの配列には制約があり、

  • 1次元配列でなければならない
  • 0始まりでなければならない
  • int32でなければならない

です。

そうでないとよくわからないエラーを吐くので注意です。

入力となるパラメータ

alphaimpact.jp

こちらのサイトがとても参考になります。 こちらのサイトの通り、

  • 複勝圏内に入るかどうかの分類問題
  • 特定の場所、芝・ダート、距離で学習

という感じでやります。 入力パラメータは上記サイトのパラメータを使用すると、回収率が100%超えるくらいには再現しました。

パラメータの標準化

パラメータは、距離なら1000 〜 2000、一位との着差タイムだと、0 〜 10など数値自体の大きさが違います。 すると、学習効率が悪いらしいので、標準化します。

下記サンプルコードのように、numpyのmean()std()をつかって平均や標準偏差などを求め、 パラメータを平均0で、-1から1までに分散するよう標準化します。

def normalize(x, mean, std):
    x_copy = np.copy(x)
    x_norm = (x_copy - mean) / std

    return x_norm

means = train_data.mean()
stds = train_data.std()
train_data = normalize(train_data, means, stds)

hyperoptでハイパーパラメータを設定する

chainerで作ったDeep Learningモデルのハイパーパラメータチューニングを自動化してみる - verilog書く人

こちらの記事を参考に、hyperoptの使用コードを実装します。

そうしてみると、意外とreluよりsigmoidのほうがいいのかな?とか

batchsizeが100より500のほうがaccuracyが安定して上昇するぞとか、わかります。

おまけ

JRA-VANのデータはCSVで提供されており、ヘルプサイトにCSVのヘッダが置いてあったりするので、 SQLiteを使うとインポートが楽です。一行目のカラム名がついたテーブルを作成してくれます。

sqlite3
> .mode csv
> .import ファイル名 テーブル名

特定の馬の前回のレース結果を取ったりするのにSQLが便利です。