From 218d13be3001240111f1b9f5c1b9985b8f9888ae Mon Sep 17 00:00:00 2001 From: "mochen.bmc" Date: Mon, 4 Dec 2023 16:53:06 +0800 Subject: [PATCH] fix --- torch_xla/csrc/init_python_bindings.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index e1b32c32725..32b9d745eb4 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1335,18 +1335,17 @@ void InitXlaModuleBindings(py::module m) { m.def( "_xla_step_marker", [](const std::string& device, const std::vector& devices, - bool wait, const std::optional compile_options_args) { + bool wait, const std::optional compile_options) { NoGilSection nogil; - auto& py_options = compile_options_args; - if (py_options.has_value()) { + if (compile_options.has_value()) { auto wrapper = lynx::CompileOptionsWrapper::GetInstance(); auto* completion_options_proto = &wrapper->completion_options_proto; auto* executable_build_options = completion_options_proto->mutable_executable_build_options(); executable_build_options->set_use_spmd_partitioning( - py_options.value()["use_spmd_partitioning"].cast()); - int num_replicas = py_options.value()["num_replicas"].cast(); - int num_partitions = py_options.value()["num_partitions"].cast(); + compile_options.value()["use_spmd_partitioning"].cast()); + int num_replicas = compile_options.value()["num_replicas"].cast(); + int num_partitions = compile_options.value()["num_partitions"].cast(); executable_build_options->set_num_partitions(num_partitions); executable_build_options->set_num_replicas(num_replicas); // TODO: device assignment move to pjrt_computation_client.cc @@ -1355,7 +1354,7 @@ void InitXlaModuleBindings(py::module m) { StepMarker(device, devices, wait); }, py::arg("device") = "", py::arg("devices"), py::arg("wait") = true, - py::arg("compile_options_args") = py::none()); + py::arg("compile_options") = py::none()); m.def("_get_stablehlo", [](const std::vector& tensors, const std::string& device, const std::vector& devices,