Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
mochen.bmc committed Dec 4, 2023
1 parent b246a06 commit 5d0148e
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 10 deletions.
5 changes: 2 additions & 3 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ def _run_step_closures():
return devctx


def mark_step(wait=False, compile_options=None):
def mark_step(wait=False):
if xu.getenv_as('XLA_EMIT_STEPLOG', bool, False):
print(
'torch_xla.core.xla_model::mark_step\n',
Expand All @@ -828,8 +828,7 @@ def mark_step(wait=False, compile_options=None):
flush=True)
torch_xla._XLAC._xla_step_marker(
torch_xla._XLAC._xla_get_default_device(), [],
wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait),
compile_options=compile_options)
wait=xu.getenv_as('XLA_SYNC_WAIT', bool, wait))
# Only emit metrics from the first local device index, to avoid emitting the
# same values from different threads.
if is_master_ordinal():
Expand Down
4 changes: 1 addition & 3 deletions torch_xla/csrc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,5 @@ cc_library(
hdrs = ["common/singleton.h",
"common/base.h",
"common/lynx_types.h"],
deps = [
"@xla//xla/pjrt:compile_options_proto_cc",
],
deps = [],
)
1 change: 0 additions & 1 deletion torch_xla/csrc/common/lynx_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include <utility>
#include <string>

#include "xla/pjrt/compile_options.pb.h"
#include "torch_xla/csrc/common/singleton.h"

namespace lynx {
Expand Down
5 changes: 2 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1335,12 +1335,11 @@ void InitXlaModuleBindings(py::module m) {
m.def(
"_xla_step_marker",
[](const std::string& device, const std::vector<std::string>& devices,
bool wait, const std::optional<py::dict> compile_options) {
bool wait) {
NoGilSection nogil;
StepMarker(device, devices, wait);
},
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true,
py::arg("compile_options") = py::none());
py::arg("device") = "", py::arg("devices"), py::arg("wait") = true);
m.def("_get_stablehlo",
[](const std::vector<at::Tensor>& tensors, const std::string& device,
const std::vector<std::string>& devices,
Expand Down

0 comments on commit 5d0148e

Please sign in to comment.