Optunaの目的関数に引数をもたせたい

Optuna

Optuna - A hyperparameter optimization framework

ハイパーパラメータを最適化するフレームワーク。詳細は公式ドキュメントなりをみていただければ。わたしはGitHubのREADME を読んだだけで使ってみた状態なので、以下に書くことが行儀の良いやり方なのかはしりません。

やりたいこと

READMEのコードには、

# Define an objective function to be minimized.
def objective(trial):
    # ... 省略


study = optuna.create_study()  # Create a new study.
study.optimize(objective, n_trials=100)  # Invoke optimization of the objective function.

と書かれている。ここで、 study.optimize は関数オブジェクトを受け取っている。言い換えると objective は trial : optuna.Trial というただひとつの引数だけを受付ける関数であるという前提がある。

一方で、いろいろな事情で最適化を複数回実行したいとか、オプション引数に合わせて目的関数を変えたいなとか思うことがあるだろう。しかし目的関数は trial 以外の引数は取らないから、たとえば is_hoge みたいなフラグ変数を作るのは難しそうだ。

と、そこでPythonは関数がファーストオブジェクトであることを思い出して、高階関数を書けばいいだけだったと思い至った。

サンプル例

下記の Optuna のランディングページ のQuick Startにあるサンプルコードを使う。

import optuna

def objective(trial):
    x = trial.suggest_uniform('x', -10, 10)
    return (x - 2) ** 2

study = optuna.create_study()
study.optimize(objective, n_trials=100)

print(study.best_params)

たとえば、 (x - 2) ** n を最適化するように変えてみよう。

import optuna

def objective_variable_degree(n):

    def objective(trial):
        x = trial.suggest_uniform('x', -10, 10)
        return (x - 2) ** n

    return objective

こんな感じで objective_variable_degree が引数を受け取った結果帰ってくるのが objective 関数であればよい。

ためしに、4次の多項式を最適化してみた。3次にしなかった理由は単に-10に近づくだけでつまらなかったので…。下記のコードを書いて適切に実行した。といってもサンプルから受け取る関数の名前を変えただけである。

study = optuna.create_study()
study.optimize(objective_variable_degree(4), n_trials=100)

print(study.best_params)

次のような出力が得られる。

[I 2019-06-26 22:50:03,850] Finished trial#0 resulted in value: 6052.213452475149. Current best value is 6052.213452475149 with parameters: {'x': -6.820202562342194}.
[I 2019-06-26 22:50:03,865] Finished trial#1 resulted in value: 10725.225690223175. Current best value is 6052.213452475149 with parameters: {'x': -6.820202562342194}.
[I 2019-06-26 22:50:03,891] Finished trial#2 resulted in value: 1403.2146449107272. Current best value is 1403.2146449107272 with parameters: {'x': 8.120417202779894}.

... 中略 ...

[I 2019-06-26 22:50:06,396] Finished trial#97 resulted in value: 198.58743888360596. Current best value is 5.240658892491164e-11 with parameters: {'x': 1.9973094165081346}.
[I 2019-06-26 22:50:06,431] Finished trial#98 resulted in value: 44.84344891326348. Current best value is 5.240658892491164e-11 with parameters: {'x': 1.9973094165081346}.
[I 2019-06-26 22:50:06,465] Finished trial#99 resulted in value: 3.433404666025541. Current best value is 5.240658892491164e-11 with parameters: {'x': 1.9973094165081346}.

{'x': 1.9973094165081346}

無事 2.0 に近い値が出てきた。めでたしめでたし。