diff --git a/optuna_integration/mlflow/mlflow.py b/optuna_integration/mlflow/mlflow.py index 431f7fee..197e8264 100644 --- a/optuna_integration/mlflow/mlflow.py +++ b/optuna_integration/mlflow/mlflow.py @@ -153,7 +153,7 @@ def wrapper(trial: optuna.trial.Trial) -> float | Sequence[float]: with self._lock: study = trial.study self._initialize_experiment(study) - nested = self._mlflow_kwargs.get("nested") + nested = bool(self._mlflow_kwargs.get("nested")) run_name = self._mlflow_kwargs.get("run_name", str(trial.number)) with mlflow.start_run(run_name=run_name, nested=nested) as run: diff --git a/tests/mlflow/test_mlflow.py b/tests/mlflow/test_mlflow.py index c9992409..0082fd19 100644 --- a/tests/mlflow/test_mlflow.py +++ b/tests/mlflow/test_mlflow.py @@ -127,7 +127,8 @@ def test_use_existing_experiment_by_id(tmpdir: py.path.local) -> None: assert experiment.experiment_id == experiment_id assert experiment.name == "foo" - runs = mlfl_client.search_runs(experiment_id) + # TODO(y0z): Remove type ignore once the MLFlow typing is fixed. + runs = mlfl_client.search_runs(experiment_id) # type: ignore assert len(runs) == 10 @@ -235,7 +236,7 @@ def test_tag_truncation(tmpdir: py.path.local) -> None: first_run_dict = first_run.to_dictionary() my_user_attr = first_run_dict["data"]["tags"]["my_user_attr"] - assert len(my_user_attr) <= 5000 + assert len(my_user_attr) <= 8000 def test_nest_trials(tmpdir: py.path.local) -> None: