Commit b246a06

update

mochen.bmc committed Dec 4, 2023
1 parent 218d13b
Showing 3 changed files with 1 addition and 54 deletions.
5 changes: 0 additions & 5 deletions torch_xla/csrc/common/lynx_types.h
@@ -22,11 +22,6 @@ struct P2PChannelsManager : public Singleton<P2PChannelsManager> {
   friend class Singleton<P2PChannelsManager>;
 };
 
-struct CompileOptionsWrapper : public Singleton<CompileOptionsWrapper> {
-  xla::CompileOptionsProto completion_options_proto;
-  bool initialized = false;
-};
-
 }  // namespace lynx
 
 #endif
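
For context on the type removed above: CompileOptionsWrapper hangs off a CRTP Singleton<T> base defined elsewhere in this tree. A minimal, self-contained sketch of that pattern (an assumed stand-in, not the repository's actual Singleton definition):

// Assumed stand-in for the repository's Singleton<T>; the real
// definition lives elsewhere in torch_xla/csrc/common.
template <typename T>
class Singleton {
 public:
  static T* GetInstance() {
    static T instance;  // constructed once, on first use (thread-safe since C++11)
    return &instance;
  }

 protected:
  Singleton() = default;
};

// Mirrors the shape of the removed CompileOptionsWrapper: plain data
// members plus the inherited GetInstance() accessor.
struct OptionsHolder : public Singleton<OptionsHolder> {
  bool initialized = false;
};

int main() {
  OptionsHolder::GetInstance()->initialized = true;
  return OptionsHolder::GetInstance()->initialized ? 0 : 1;
}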
14 changes: 0 additions & 14 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1337,20 +1337,6 @@ void InitXlaModuleBindings(py::module m) {
         [](const std::string& device, const std::vector<std::string>& devices,
            bool wait, const std::optional<py::dict> compile_options) {
           NoGilSection nogil;
-          if (compile_options.has_value()) {
-            auto wrapper = lynx::CompileOptionsWrapper::GetInstance();
-            auto* completion_options_proto = &wrapper->completion_options_proto;
-            auto* executable_build_options =
-                completion_options_proto->mutable_executable_build_options();
-            executable_build_options->set_use_spmd_partitioning(
-                compile_options.value()["use_spmd_partitioning"].cast<bool>());
-            int num_replicas = compile_options.value()["num_replicas"].cast<int>();
-            int num_partitions = compile_options.value()["num_partitions"].cast<int>();
-            executable_build_options->set_num_partitions(num_partitions);
-            executable_build_options->set_num_replicas(num_replicas);
-            // TODO: device assignment move to pjrt_computation_client.cc
-            wrapper->initialized = true;
-          }
           StepMarker(device, devices, wait);
         },
         py::arg("device") = "", py::arg("devices"), py::arg("wait") = true,
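
The deleted binding code above read three keys out of an optional Python dict and stashed them in the process-wide wrapper before emitting the step marker. A hedged sketch of that flow, with plain C++ stand-ins for the pybind11 and XLA proto types (none of these stand-in names are from the real API):

#include <iostream>
#include <map>
#include <optional>
#include <string>

// Stand-in for the executable build options inside the proto; assumed shape.
struct BuildOptions {
  bool use_spmd_partitioning = false;
  int num_replicas = 1;
  int num_partitions = 1;
};

// Stand-in for the removed lynx::CompileOptionsWrapper singleton.
struct OptionsWrapper {
  BuildOptions build_options;
  bool initialized = false;
  static OptionsWrapper* GetInstance() {
    static OptionsWrapper wrapper;
    return &wrapper;
  }
};

// Mirrors the removed lambda body: copy caller-supplied options into the
// singleton, then fall through to the step marker.
void StepMarkerHook(const std::optional<std::map<std::string, int>>& opts) {
  if (opts.has_value()) {
    auto* wrapper = OptionsWrapper::GetInstance();
    wrapper->build_options.use_spmd_partitioning =
        opts->at("use_spmd_partitioning") != 0;
    wrapper->build_options.num_replicas = opts->at("num_replicas");
    wrapper->build_options.num_partitions = opts->at("num_partitions");
    wrapper->initialized = true;
  }
  // ...StepMarker(device, devices, wait) would run here...
}

int main() {
  StepMarkerHook(std::map<std::string, int>{
      {"use_spmd_partitioning", 1}, {"num_replicas", 2}, {"num_partitions", 4}});
  const auto* w = OptionsWrapper::GetInstance();
  std::cout << w->build_options.num_replicas << " x "
            << w->build_options.num_partitions << "\n";  // prints "2 x 4"
}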
36 changes: 1 addition & 35 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -489,43 +489,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
       tsl::profiler::TraceMeLevel::kInfo);
   std::vector<ComputationClient::ComputationPtr> computations;
 
-  auto wrapper = lynx::CompileOptionsWrapper::GetInstance();
-
   for (auto& instance : instances) {
     xla::CompileOptions compile_options;
-    if (wrapper->initialized) {
-      auto res =
-          xla::CompileOptions::FromProto(wrapper->completion_options_proto);
-      if (TF_PREDICT_FALSE(!res.ok())) {
-        XLA_ERROR() << "Failed to call xla::CompileOptions::FromProto(proto).";
-        continue;
-      }
-      compile_options = std::move(res).value();
-      compile_options.parameter_is_tupled_arguments =
-          instance.parameter_is_tupled_arguments;
-      auto replica_count = compile_options.executable_build_options.num_replicas();
-      auto partition_count = compile_options.executable_build_options.num_partitions();
-      compile_options.executable_build_options
-          .set_allow_spmd_sharding_propagation_to_output(
-              {instance.allow_spmd_sharding_propagation_to_output});
-      xla::DeviceAssignment device_assignment(replica_count, partition_count);
-      std::unordered_map<int, int> revert_global_ordinals;
-      for (const auto& [device_id, global_ordinal] : global_ordinals_) {
-        revert_global_ordinals[global_ordinal] = device_id;
-      }
-      // DeviceAssignment values must be the PjRtDevice ID, so we need to
-      // unwind the global ordinal mapping.
-      for (int64_t partition_id = 0; partition_id < partition_count;
-           ++partition_id) {
-        for (int64_t replica_id = 0; replica_id < replica_count; ++replica_id) {
-          int64_t flattened_id = replica_id * partition_count + partition_id;
-          device_assignment(replica_id, partition_id) =
-              revert_global_ordinals[flattened_id];
-        }
-      }
-      compile_options.executable_build_options.set_device_assignment(
-          device_assignment);
-    } else if (instance.is_sharded) {
+    if (instance.is_sharded) {
      // TODO(yeounoh) multi-host, multi-slice configurations
      compile_options.executable_build_options.set_use_spmd_partitioning(true);
      // We can override the compiler's default behavior to replicate the
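
The largest removed block built an xla::DeviceAssignment by inverting global_ordinals_ (PjRtDevice ID -> global ordinal) so that each (replica, partition) cell holds a device ID rather than an ordinal. A self-contained sketch of just that unwinding step, using a plain 2-D vector in place of xla::DeviceAssignment and made-up device IDs:

#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical global_ordinals_: PjRtDevice ID -> global ordinal.
  const std::unordered_map<int, int> global_ordinals = {
      {7, 0}, {3, 1}, {5, 2}, {1, 3}};
  const int64_t replica_count = 2;
  const int64_t partition_count = 2;

  // Invert the map, as the removed code did: global ordinal -> device ID.
  std::unordered_map<int, int> revert_global_ordinals;
  for (const auto& [device_id, global_ordinal] : global_ordinals) {
    revert_global_ordinals[global_ordinal] = device_id;
  }

  // Stand-in for xla::DeviceAssignment: a replica x partition grid whose
  // cells hold device IDs, filled via the same flattening rule
  // (flattened_id = replica_id * partition_count + partition_id).
  std::vector<std::vector<int>> device_assignment(
      replica_count, std::vector<int>(partition_count));
  for (int64_t partition_id = 0; partition_id < partition_count; ++partition_id) {
    for (int64_t replica_id = 0; replica_id < replica_count; ++replica_id) {
      const int64_t flattened_id = replica_id * partition_count + partition_id;
      device_assignment[replica_id][partition_id] =
          revert_global_ordinals.at(flattened_id);
    }
  }

  for (const auto& row : device_assignment) {
    for (int id : row) std::cout << id << ' ';  // prints "7 3" then "5 1"
    std::cout << '\n';
  }
}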
