Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
mochen.bmc committed Dec 4, 2023
1 parent 3399ec2 commit 218d13b
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,18 +1335,17 @@ void InitXlaModuleBindings(py::module m) {
m.def(
"_xla_step_marker",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait, const std::optional<py::dict> compile_options_args) {
bool wait, const std::optional<py::dict> compile_options) {
NoGilSection nogil;
auto& py_options = compile_options_args;
if (py_options.has_value()) {
if (compile_options.has_value()) {
auto wrapper = lynx::CompileOptionsWrapper::GetInstance();
auto* completion_options_proto = &wrapper->completion_options_proto;
auto* executable_build_options =
completion_options_proto->mutable_executable_build_options();
executable_build_options->set_use_spmd_partitioning(
py_options.value()["use_spmd_partitioning"].cast<bool>());
int num_replicas = py_options.value()["num_replicas"].cast<int>();
int num_partitions = py_options.value()["num_partitions"].cast<int>();
compile_options.value()["use_spmd_partitioning"].cast<bool>());
int num_replicas = compile_options.value()["num_replicas"].cast<int>();
int num_partitions = compile_options.value()["num_partitions"].cast<int>();
executable_build_options->set_num_partitions(num_partitions);
executable_build_options->set_num_replicas(num_replicas);
// TODO: device assignment move to pjrt_computation_client.cc
Expand All @@ -1355,7 +1354,7 @@ void InitXlaModuleBindings(py::module m) {
StepMarker(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true,
py::arg("compile_options_args") = py::none());
py::arg("compile_options") = py::none());
m.def("_get_stablehlo",
[](const std::vector<at::Tensor>& tensors, const std::string& device,
const std::vector<std::string>& devices,
Expand Down

0 comments on commit 218d13b

Please sign in to comment.