diff --git a/codegen/xla_native_functions.yaml b/codegen/xla_native_functions.yaml index a25244680f51..7b04960ad82f 100644 --- a/codegen/xla_native_functions.yaml +++ b/codegen/xla_native_functions.yaml @@ -84,6 +84,7 @@ full_codegen: - sinh - softshrink - softshrink_backward + - sqrt - take - tan - tanh @@ -313,7 +314,6 @@ supported: - sort.stable - split_copy.Tensor - split_with_sizes_copy - - sqrt - squeeze_copy - squeeze_copy.dim - squeeze_copy.dims @@ -373,7 +373,6 @@ supported: - narrow_copy - pixel_shuffle - pixel_unshuffle - - reshape - select_backward - select.int - slice.Tensor @@ -406,8 +405,6 @@ symint: - narrow_copy - select_backward - select.int - # See Note: [functionalization and CompositeExplicitAutograd] - - reshape # See Note: [Disabling functionalization] - expand - view diff --git a/infra/tpu-pytorch-releases/artifacts.auto.tfvars b/infra/tpu-pytorch-releases/artifacts.auto.tfvars index 11f9bc6f427f..adf971a87921 100644 --- a/infra/tpu-pytorch-releases/artifacts.auto.tfvars +++ b/infra/tpu-pytorch-releases/artifacts.auto.tfvars @@ -25,54 +25,54 @@ nightly_builds = [ versioned_builds = [ # Remove libtpu from PyPI builds { - git_tag = "v2.2.0-rc5" - package_version = "2.2.0rc5" - pytorch_git_rev = "v2.2.0-rc5" + git_tag = "v2.2.0-rc6" + package_version = "2.2.0rc6" + pytorch_git_rev = "v2.2.0-rc6" accelerator = "tpu" bundle_libtpu = "0" }, { - git_tag = "v2.2.0-rc5" - package_version = "2.2.0rc5" - pytorch_git_rev = "v2.2.0-rc5" + git_tag = "v2.2.0-rc6" + package_version = "2.2.0rc6" + pytorch_git_rev = "v2.2.0-rc6" accelerator = "tpu" python_version = "3.9" bundle_libtpu = "0" }, { - git_tag = "v2.2.0-rc5" - package_version = "2.2.0rc5" - pytorch_git_rev = "v2.2.0-rc5" + git_tag = "v2.2.0-rc6" + package_version = "2.2.0rc6" + pytorch_git_rev = "v2.2.0-rc6" accelerator = "tpu" python_version = "3.10" bundle_libtpu = "0" }, { - git_tag = "v2.2.0-rc5" - package_version = "2.2.0rc5" - pytorch_git_rev = "v2.2.0-rc5" + git_tag = "v2.2.0-rc6" + package_version = "2.2.0rc6" + pytorch_git_rev = "v2.2.0-rc6" accelerator = "tpu" python_version = "3.11" bundle_libtpu = "0" }, # Bundle libtpu for Kaggle { - git_tag = "v2.2.0-rc5" - package_version = "2.2.0rc5+libtpu" - pytorch_git_rev = "v2.2.0-rc5" + git_tag = "v2.2.0-rc6" + package_version = "2.2.0rc6+libtpu" + pytorch_git_rev = "v2.2.0-rc6" accelerator = "tpu" python_version = "3.10" bundle_libtpu = "1" }, { - git_tag = "v2.2.0-rc5" - package_version = "2.2.0rc5" + git_tag = "v2.2.0-rc6" + package_version = "2.2.0rc6" accelerator = "cuda" cuda_version = "12.1" }, { - git_tag = "v2.2.0-rc5" - package_version = "2.2.0rc5" + git_tag = "v2.2.0-rc6" + package_version = "2.2.0rc6" accelerator = "cuda" cuda_version = "12.1" python_version = "3.10" diff --git a/plugins/cuda/README.md b/plugins/cuda/README.md new file mode 100644 index 000000000000..f5a2647f6e6f --- /dev/null +++ b/plugins/cuda/README.md @@ -0,0 +1,41 @@ +# CUDA PJRT plugin (experimental) + +This directory contains an experimental implementation of the PJRT GPU client as +a plugin. The actual implementation of the PJRT C API lives in the main OpenXLA +repository (see `bazel build` command below). + +## Building + +```bash +# Build PJRT plugin +bazel build @xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so --cxxopt=-D_GLIBCXX_USE_CXX11_ABI=1 --config=cuda +# Copy to package dir +cp bazel-bin/external/xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so plugins/cuda/torch_xla_cuda_plugin + +# Build wheel +pip wheel plugins/cuda +# Or install directly +pip install plugins/cuda +``` + +## Usage + +```python +import os + +# Log device type +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' +os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5' + +from torch_xla.experimental import plugins +import torch_xla_cuda_plugin +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +# Use dynamic plugin instead of built-in CUDA support +plugins.use_dynamic_plugins() +plugins.register_plugin('CUDA', torch_xla_cuda_plugin.GpuPlugin()) +xr.set_device_type('CUDA') + +print(xm.xla_device()) +``` diff --git a/plugins/cuda/pyproject.toml b/plugins/cuda/pyproject.toml new file mode 100644 index 000000000000..306b30495ea0 --- /dev/null +++ b/plugins/cuda/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "torch_xla_cuda_plugin" +version = "0.0.1" +authors = [ + {name = "Will Cromar", email = "wcromar@google.com"}, +] +description = "CUDA Plugin" +requires-python = ">=3.8" + +[tool.setuptools.package-data] +torch_xla_cuda_plugin = ["*.so"] + +[project.entry-points."torch_xla.plugins"] +gpu = "torch_xla_cuda_plugin:GpuPlugin" diff --git a/plugins/cuda/torch_xla_cuda_plugin/__init__.py b/plugins/cuda/torch_xla_cuda_plugin/__init__.py new file mode 100644 index 000000000000..d08f512e683d --- /dev/null +++ b/plugins/cuda/torch_xla_cuda_plugin/__init__.py @@ -0,0 +1,11 @@ +import os +from torch_xla.experimental import plugins +from torch_xla._internal import tpu + +class GpuPlugin(plugins.DevicePlugin): + def library_path(self) -> str: + return os.path.join(os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so') + + def physical_chip_count(self) -> int: + # TODO: default to actual device count + return os.getenv('GPU_NUM_DEVICES', 1) diff --git a/test/pjrt/test_dtypes.py b/test/pjrt/test_dtypes.py new file mode 100644 index 000000000000..ebac882efdf4 --- /dev/null +++ b/test/pjrt/test_dtypes.py @@ -0,0 +1,35 @@ +from absl.testing import absltest, parameterized +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + + +class TestDtypes(parameterized.TestCase): + + @parameterized.parameters(torch.float16, torch.float32, torch.float64, + torch.bfloat16, torch.complex64) + def test_float_round_trip(self, dtype: torch.dtype): + t = torch.randn((3, 3), dtype=dtype) + xt = t.to(xm.xla_device()) + torch.testing.assert_close(xt.cpu(), t) + + @parameterized.parameters( + torch.uint8, + torch.int8, + torch.int16, + torch.int32, + torch.int64, + ) + def test_int_round_trip(self, dtype: torch.dtype): + t = torch.randint(0, 128, (3, 3), dtype=dtype) + xt = t.to(xm.xla_device()) + torch.testing.assert_close(xt.cpu(), t) + + def test_bool_round_trip(self): + t = torch.randint(0, 2, (3, 3), dtype=torch.bool) + xt = t.to(xm.xla_device()) + torch.testing.assert_close(xt.cpu(), t) + + +if __name__ == "__main__": + absltest.main() diff --git a/test/run_tests.sh b/test/run_tests.sh index 4e5dc6e90f6b..f73dc156df76 100755 --- a/test/run_tests.sh +++ b/test/run_tests.sh @@ -128,7 +128,7 @@ function run_torchrun { echo "Running torchrun test for GPU $@" num_devices=$(nvidia-smi --list-gpus | wc -l) PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node $num_devices $@ - fi + fi } function run_torch_op_tests { @@ -190,6 +190,7 @@ function run_xla_op_tests1 { # DO NOT MODIFY function run_xla_op_tests2 { run_downcast_bf16 "$CDIR/test_data_type.py" + run_test "$CDIR/pjrt/test_dtypes.py" run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU } @@ -235,6 +236,7 @@ function run_mp_op_tests { run_test "$CDIR/test_mp_save.py" run_test "$CDIR/test_mp_mesh_reduce.py" run_test "$CDIR/test_mp_sync_batch_norm.py" + run_test "$CDIR/test_mp_early_exit.py" run_pt_xla_debug "$CDIR/debug_tool/test_mp_pt_xla_debug.py" run_xla_backend_mp "$CDIR/test_torch_distributed_all_gather_xla_backend.py" run_xla_backend_mp "$CDIR/test_torch_distributed_all_reduce_xla_backend.py" diff --git a/test/test_core_aten_ops.py b/test/test_core_aten_ops.py index 9f8f786d0545..e01f35196c32 100644 --- a/test/test_core_aten_ops.py +++ b/test/test_core_aten_ops.py @@ -295,7 +295,6 @@ def test_aten__adaptive_avg_pool3d_1(self): run_export_and_compare(self, torch.ops.aten._adaptive_avg_pool3d, args, kwargs) - @unittest.skip def test_aten_add_Scalar_0(self): args = ( torch.randn((10, 10)).to(torch.float32), @@ -1641,13 +1640,18 @@ def test_aten_expm1_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs) - @unittest.skip def test_aten_expm1_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() - run_export_and_compare(self, torch.ops.aten.expm1, args, kwargs) + run_export_and_compare( + self, + torch.ops.aten.expm1, + args, + kwargs, + rtol=0.001, + atol=0.01, + ) - @unittest.skip def test_aten_expm1_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -1757,7 +1761,6 @@ def test_aten_floor_divide_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.floor_divide, args, kwargs) - @unittest.skip def test_aten_floor_divide_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -2475,7 +2478,6 @@ def test_aten_logical_or_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.logical_or, args, kwargs) - @unittest.skip def test_aten_logical_or_1(self): args = ( torch.randn((10, 10)).to(torch.float16), @@ -3347,7 +3349,6 @@ def test_aten_prod_0(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.prod, args, kwargs) - @unittest.skip def test_aten_prod_1(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -4018,7 +4019,6 @@ def test_aten_sinh_1(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.sinh, args, kwargs) - @unittest.skip def test_aten_sinh_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() @@ -4197,19 +4197,16 @@ def test_aten_split_with_sizes_2(self): kwargs = dict() run_export_and_compare(self, torch.ops.aten.split_with_sizes, args, kwargs) - @unittest.skip def test_aten_sqrt_0(self): args = (torch.randn((10, 10)).to(torch.float32),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) - @unittest.skip def test_aten_sqrt_1(self): args = (torch.randn((10, 10)).to(torch.float16),) kwargs = dict() run_export_and_compare(self, torch.ops.aten.sqrt, args, kwargs) - @unittest.skip def test_aten_sqrt_2(self): args = (torch.randint(0, 10, (10, 10)).to(torch.int32),) kwargs = dict() diff --git a/test/test_mp_early_exit.py b/test/test_mp_early_exit.py new file mode 100644 index 000000000000..837aea1751be --- /dev/null +++ b/test/test_mp_early_exit.py @@ -0,0 +1,26 @@ +import sys +import torch +import torch_xla +import torch_xla.core.xla_model as xm +import torch_xla.distributed.parallel_loader as pl +import torch_xla.distributed.xla_multiprocessing as xmp +import torch_xla.utils.utils as xu + + +def _mp_fn(index): + device = xm.xla_device() + if xm.xla_device_hw(device) in ('TPU', 'GPU', 'CUDA', 'ROCM', 'NEURON'): + train_loader = xu.SampleGenerator( + data=torch.zeros(1, 12), sample_count=1024) + train_loader = pl.MpDeviceLoader(train_loader, device) + max_steps = 10 + for step, inputs in enumerate(train_loader): + xm.all_reduce('sum', [inputs], scale=1.0 / xm.xrt_world_size()) + if step > max_steps: + break + else: + print(f'{device} is not a TPU or GPU device', file=sys.stderr) + + +if __name__ == '__main__': + xmp.spawn(_mp_fn, args=()) diff --git a/test/tpu/xla_test_job.yaml b/test/tpu/xla_test_job.yaml index faf2d67ba08c..c65b2e5692e7 100644 --- a/test/tpu/xla_test_job.yaml +++ b/test/tpu/xla_test_job.yaml @@ -57,6 +57,7 @@ spec: python3 /src/pytorch/xla/test/test_autocast.py python3 /src/pytorch/xla/test/dynamo/test_dynamo.py python3 /src/pytorch/xla/test/spmd/test_spmd_debugging.py + python3 /src/pytorch/xla/test/pjrt/test_dtypes.py python3 /src/pytorch/xla/test/pjrt/test_dynamic_plugin_tpu.py volumeMounts: - mountPath: /dev/shm diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index d753f8f7c8f2..8d4997e28556 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -148,6 +148,8 @@ def _setup_tpu_vm_library_path() -> bool: def _prepare_to_exit(): + device = _XLAC._xla_get_default_device() + _XLAC._set_all_reduce_token(device, None) _XLAC._prepare_to_exit() if int(os.environ.get('PT_XLA_DEBUG', '0')): _summarize_fn_tracker() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index 1e82e1f352e9..5322e8265e2c 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -2888,12 +2888,6 @@ std::vector XLANativeFunctions::split_with_sizes_copy( return bridge::AtenFromXlaTensors(xla_tensors); } -at::Tensor XLANativeFunctions::sqrt(const at::Tensor& self) { - TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); - return bridge::AtenFromXlaTensor( - tensor_methods::sqrt(bridge::GetXlaTensor(self))); -} - at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); return bridge::AtenFromXlaTensor( @@ -3649,16 +3643,6 @@ at::Tensor XLANativeFunctions::pixel_unshuffle(const at::Tensor& self, pixel_unshuffle)>::call(self, downscale_factor); } -at::Tensor XLANativeFunctions::reshape_symint(const at::Tensor& self, - c10::SymIntArrayRef shape) { - // See Note: [Disabling functionalization] - if (runtime::sys_util::GetEnvBool("XLA_DISABLE_FUNCTIONALIZATION", false)) { - return at::native::reshape_symint(self, shape); - } - return at::functionalization::functionalize_aten_op_symint::call(self, shape); -} - at::Tensor XLANativeFunctions::select_backward_symint( const at::Tensor& grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index a920bdb69e9d..cd86e0f31697 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -15,11 +15,6 @@ namespace torch_xla { namespace { -xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) { - xla::XlaOp zero = xla::Zero(op.builder(), from); - return xla::Ne(op, zero); -} - xla::XlaOp CreateRawMask(xla::XlaOp op, xla::PrimitiveType type, int64_t size, int64_t narrow_size) { uint64_t mask_value = @@ -53,50 +48,20 @@ xla::XlaOp ConvertData(xla::XlaOp op, xla::PrimitiveType type, } // namespace xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, - xla::PrimitiveType to, - const torch::lazy::BackendDevice* device) { + xla::PrimitiveType to) { if (from == to) { return op; } - XlaDeviceType hw_type = - static_cast(bridge::GetDeviceOrCurrent(device).type()); - if (hw_type != XlaDeviceType::TPU) { - return xla::ConvertElementType(op, to); - } - switch (from) { - case xla::PrimitiveType::PRED: - case xla::PrimitiveType::S8: - case xla::PrimitiveType::U8: - case xla::PrimitiveType::S16: - case xla::PrimitiveType::U16: - case xla::PrimitiveType::S32: - case xla::PrimitiveType::U32: - case xla::PrimitiveType::BF16: - case xla::PrimitiveType::F32: - return xla::ConvertElementType(op, to); - case xla::PrimitiveType::S64: - case xla::PrimitiveType::U64: { - switch (to) { - case xla::PrimitiveType::PRED: - return ExplicitBooleanConvert(op, from); - default: - return xla::ConvertElementType(op, to); - } - break; - } - default: - XLA_ERROR() << "Unsupported XLA type " << from; - } + return xla::ConvertElementType(op, to); } xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, xla::PrimitiveType raw_from, xla::PrimitiveType to, - xla::PrimitiveType raw_to, - const torch::lazy::BackendDevice* device) { + xla::PrimitiveType raw_to) { if (from != raw_from) { op = ConvertData(op, from, raw_from); } - xla::XlaOp result = ConvertTo(op, from, to, device); + xla::XlaOp result = ConvertTo(op, from, to); return to == raw_to ? result : ConvertData(result, to, raw_to); } @@ -105,8 +70,7 @@ xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) { torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); op = ConvertTo( op, from, - MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device), - &xla_device); + MaybeDowncastToXlaDeviceType(xla::PrimitiveType::U8, xla_device)); } return op; } @@ -120,7 +84,7 @@ xla::XlaOp CastToScalarType(xla::XlaOp input, if (dtype) { torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice(); return ConvertTo(input, XlaHelpers::TypeOfXlaOp(input), - MakeXlaPrimitiveType(*dtype, &xla_device), &xla_device); + MakeXlaPrimitiveType(*dtype, &xla_device)); } return ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input)); } diff --git a/torch_xla/csrc/convert_ops.h b/torch_xla/csrc/convert_ops.h index 3dd0ce99f3a4..029599667bd4 100644 --- a/torch_xla/csrc/convert_ops.h +++ b/torch_xla/csrc/convert_ops.h @@ -11,13 +11,11 @@ namespace torch_xla { xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, - xla::PrimitiveType to, - const torch::lazy::BackendDevice* device); + xla::PrimitiveType to); xla::XlaOp ConvertToRaw(xla::XlaOp op, xla::PrimitiveType from, xla::PrimitiveType raw_from, xla::PrimitiveType to, - xla::PrimitiveType raw_to, - const torch::lazy::BackendDevice* device); + xla::PrimitiveType raw_to); xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from); @@ -32,4 +30,4 @@ xla::XlaOp MaybeConvertTo(xla::XlaOp input, xla::PrimitiveType type); } // namespace torch_xla -#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_ \ No newline at end of file +#endif // XLA_TORCH_XLA_CSRC_CONVERT_OPS_H_ diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index bb02ca7da2ef..3eb04baa8e8f 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -160,7 +160,7 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask, xla::XlaOp mask_pred = xla::Ne(mask, zero); xla::XlaOp update_scalar = ConvertTo(scalar, ShapeHelper::ShapeOfXlaOp(scalar).element_type(), - ShapeHelper::ShapeOfXlaOp(input).element_type(), nullptr); + ShapeHelper::ShapeOfXlaOp(input).element_type()); return xla::Select(mask_pred, update_scalar, input); } @@ -291,7 +291,7 @@ xla::XlaOp BuildUpdateSlice(xla::XlaOp input, xla::XlaOp source, xla::XlaOp update_source = source; if (source_shape.element_type() != input_shape.element_type()) { update_source = ConvertTo(source, source_shape.element_type(), - input_shape.element_type(), /*device=*/nullptr); + input_shape.element_type()); } xla::XlaOp reshaped_source = XlaHelpers::ReshapeToRank(update_source, input_shape.rank()); diff --git a/torch_xla/csrc/dtype.cpp b/torch_xla/csrc/dtype.cpp index 103630bcec8f..d4ed1b413a63 100644 --- a/torch_xla/csrc/dtype.cpp +++ b/torch_xla/csrc/dtype.cpp @@ -75,16 +75,6 @@ bool Use32BitLong() { return use_32bit_long; } -bool IsTpuDevice(XlaDeviceType hw_type) { - static bool spmd_device_is_tpu = - (hw_type == XlaDeviceType::SPMD) && - // HACK: find a better way to decide if SPMD is actually a TPU without - // accessing the runtime. - runtime::sys_util::GetEnvString("PJRT_DEVICE", "").find("TPU") != - std::string::npos; - return (hw_type == XlaDeviceType::TPU) || spmd_device_is_tpu; -} - } // namespace at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) { @@ -163,8 +153,7 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( if (UseBF16()) { return xla::PrimitiveType::BF16; } - if (DowncastBF16() || DowncastF16() || IsTpuDevice(hw_type) || - hw_type == XlaDeviceType::NEURON) { + if (DowncastBF16() || DowncastF16() || hw_type == XlaDeviceType::NEURON) { return xla::PrimitiveType::F32; } return xla::PrimitiveType::F64; @@ -175,20 +164,17 @@ xla::PrimitiveType MaybeDowncastToXlaDeviceType( return UseBF16() || DowncastBF16() ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32; case xla::PrimitiveType::U16: - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::U16 - : xla::PrimitiveType::U32; + return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::U16 + : xla::PrimitiveType::U32; case xla::PrimitiveType::S16: - return !IsTpuDevice(hw_type) && hw_type != XlaDeviceType::NEURON - ? xla::PrimitiveType::S16 - : xla::PrimitiveType::S32; + return hw_type != XlaDeviceType::NEURON ? xla::PrimitiveType::S16 + : xla::PrimitiveType::S32; case xla::PrimitiveType::S64: return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64; case xla::PrimitiveType::U64: return Use32BitLong() ? xla::PrimitiveType::U32 : xla::PrimitiveType::U64; case xla::PrimitiveType::C128: - return !IsTpuDevice(hw_type) ? xla::PrimitiveType::C128 - : xla::PrimitiveType::C64; + return xla::PrimitiveType::C128; default: return type; } diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 1e4f3be9f788..895b9f9279e1 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -29,7 +29,7 @@ xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2, xla::PrimitiveType type2 = XlaHelpers::TypeOfXlaOp(op2); xla::PrimitiveType result_type = XlaHelpers::TypeOfXlaOp(result); if (type1 == type2 && type1 != result_type) { - return ConvertTo(result, result_type, type1, /*device=*/nullptr); + return ConvertTo(result, result_type, type1); } return result; } @@ -489,10 +489,10 @@ std::pair XlaHelpers::PromoteValues(xla::XlaOp op1, xla::PrimitiveType type2 = TypeOfXlaOp(op2); xla::PrimitiveType result_type = PromoteType(type1, type2); if (type1 != result_type) { - op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr); + op1 = ConvertTo(op1, type1, result_type); } if (type2 != result_type) { - op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr); + op2 = ConvertTo(op2, type2, result_type); } return std::pair(op1, op2); } @@ -504,13 +504,13 @@ std::tuple XlaHelpers::PromoteValues( xla::PrimitiveType type3 = TypeOfXlaOp(op3); xla::PrimitiveType result_type = PromoteType(type1, type2, type3); if (type1 != result_type) { - op1 = ConvertTo(op1, type1, result_type, /*device=*/nullptr); + op1 = ConvertTo(op1, type1, result_type); } if (type2 != result_type) { - op2 = ConvertTo(op2, type2, result_type, /*device=*/nullptr); + op2 = ConvertTo(op2, type2, result_type); } if (type3 != result_type) { - op3 = ConvertTo(op3, type3, result_type, /*device=*/nullptr); + op3 = ConvertTo(op3, type3, result_type); } return std::tuple(op1, op2, op3); } @@ -519,10 +519,9 @@ std::pair XlaHelpers::PromoteSecondValue( xla::XlaOp op1, xla::XlaOp op2) { xla::PrimitiveType type1 = TypeOfXlaOp(op1); xla::PrimitiveType type2 = TypeOfXlaOp(op2); - return type1 == type2 - ? std::pair(op1, op2) - : std::pair( - op1, ConvertTo(op2, type2, type1, /*device=*/nullptr)); + return type1 == type2 ? std::pair(op1, op2) + : std::pair( + op1, ConvertTo(op2, type2, type1)); } xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1, diff --git a/torch_xla/csrc/matrix.cpp b/torch_xla/csrc/matrix.cpp index 0cffa5c8d2af..eccfc759a3d1 100644 --- a/torch_xla/csrc/matrix.cpp +++ b/torch_xla/csrc/matrix.cpp @@ -110,7 +110,7 @@ xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input, xla::XlaOp diag_input = input; if (target_shape->element_type() != input_shape.element_type()) { diag_input = ConvertTo(input, input_shape.element_type(), - target_shape->element_type(), /*device=*/nullptr); + target_shape->element_type()); } std::vector permutation; xla::XlaOp diag_target = target; diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp index 95068640a278..f1a0a1a90725 100644 --- a/torch_xla/csrc/ops/cast.cpp +++ b/torch_xla/csrc/ops/cast.cpp @@ -55,8 +55,7 @@ XlaOpVector Cast::Lower(LoweringContext* loctx) const { stype_ ? XlaTypeFromTorchType(*stype_) : input_shape.element_type(); xla::PrimitiveType raw_to = dtype_ ? XlaTypeFromTorchType(*dtype_) : type_; xla::XlaOp output = - ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to, - /*device=*/nullptr); + ConvertToRaw(input, input_shape.element_type(), raw_from, type_, raw_to); return ReturnOp(output, loctx); } diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 494f79c2b096..0e1e885dde4a 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -262,10 +262,8 @@ torch::lazy::NodePtr Clamp(const torch::lazy::Value& input, xla::XlaOp xla_min = loctx->GetOutputOp(node.operand(1)); xla::XlaOp xla_max = loctx->GetOutputOp(node.operand(2)); xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type, - /*device=*/nullptr); - xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type, - /*device=*/nullptr); + xla_min = ConvertTo(xla_min, XlaHelpers::TypeOfXlaOp(xla_min), input_type); + xla_max = ConvertTo(xla_max, XlaHelpers::TypeOfXlaOp(xla_max), input_type); return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx); }; return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max}, @@ -412,7 +410,7 @@ torch::lazy::NodePtr Where(const torch::lazy::Value& condition, xla::XlaOp xla_other = loctx->GetOutputOp(node.operand(2)); xla::XlaOp pred_condition = ConvertTo(xla_condition, XlaHelpers::TypeOfXlaOp(xla_condition), - xla::PrimitiveType::PRED, /*device=*/nullptr); + xla::PrimitiveType::PRED); auto promoted_branches = XlaHelpers::ValidateShapes(xla_input, xla_other); return node.ReturnOp(xla::Select(pred_condition, promoted_branches.first, promoted_branches.second), diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index a86d7cee1687..e45c7782eb50 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -25,8 +25,7 @@ torch_xla::XlaOpVector Acos::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return ReturnOp(xla::Acos(xla_input), loctx); } @@ -393,6 +392,9 @@ torch_xla::XlaOpVector Exp::Lower(LoweringContext* loctx) const { torch_xla::XlaOpVector Expm1::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { + xla_input = xla::ConvertElementType(xla_input, xla::PrimitiveType::F32); + } return ReturnOp(xla::Expm1(xla_input), loctx); } @@ -734,6 +736,10 @@ torch_xla::XlaOpVector Sin::Lower(LoweringContext* loctx) const { torch_xla::XlaOpVector Sinh::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { + xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); + } return ReturnOp(xla::Sinh(xla_input), loctx); } @@ -757,6 +763,15 @@ torch_xla::XlaOpVector SoftshrinkBackward::Lower(LoweringContext* loctx) const { // return ReturnOps({result.sign, result.logdet}, loctx); // } +torch_xla::XlaOpVector Sqrt::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { + xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); + } + return ReturnOp(xla::Sqrt(xla_input), loctx); +} + torch_xla::XlaOpVector Take::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); xla::XlaOp xla_index = loctx->GetOutputOp(operand(1)); @@ -768,8 +783,7 @@ torch_xla::XlaOpVector Tan::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return ReturnOp(xla::Tan(xla_input), loctx); } @@ -778,8 +792,7 @@ torch_xla::XlaOpVector Tanh::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); if (xla::primitive_util::IsIntegralType(XlaHelpers::TypeOfXlaOp(xla_input))) { xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(xla_input); - xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32, - /*device=*/nullptr); + xla_input = ConvertTo(xla_input, input_type, xla::PrimitiveType::F32); } return ReturnOp(xla::Tanh(xla_input), loctx); } diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index 74085eb37f3c..a442e6f44267 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -467,7 +467,11 @@ xla::Shape ExpOutputShape(const torch::lazy::Value& input) { } xla::Shape Expm1OutputShape(const torch::lazy::Value& input) { - return GetXlaShape(input); + xla::Shape result_shape = GetXlaShape(input); + if (xla::primitive_util::IsIntegralType(result_shape.element_type())) { + result_shape.set_element_type(xla::PrimitiveType::F32); + } + return result_shape; } xla::Shape FloorOutputShape(const torch::lazy::Value& input) { @@ -807,7 +811,11 @@ xla::Shape SinOutputShape(const torch::lazy::Value& input) { } xla::Shape SinhOutputShape(const torch::lazy::Value& input) { - return GetXlaShape(input); + xla::Shape result_shape = GetXlaShape(input); + if (xla::primitive_util::IsIntegralType(result_shape.element_type())) { + result_shape.set_element_type(xla::PrimitiveType::F32); + } + return result_shape; } xla::Shape SoftshrinkOutputShape(const torch::lazy::Value& self, @@ -831,6 +839,14 @@ xla::Shape SoftshrinkBackwardOutputShape(const torch::lazy::Value& grad_out, // return InferOutputShape({GetXlaShape(input)}, lower_for_shape_fn); // } +xla::Shape SqrtOutputShape(const torch::lazy::Value& input) { + xla::Shape result_shape = GetXlaShape(input); + if (xla::primitive_util::IsIntegralType(result_shape.element_type())) { + result_shape.set_element_type(xla::PrimitiveType::F32); + } + return result_shape; +} + xla::Shape TanOutputShape(const torch::lazy::Value& input) { xla::Shape result_shape = GetXlaShape(input); if (xla::primitive_util::IsIntegralType(result_shape.element_type())) { diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 57cda7b83f20..6f961f50cde9 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -268,6 +268,8 @@ xla::Shape SoftshrinkBackwardOutputShape(const torch::lazy::Value& grad_out, /* Blocked on https://github.com/pytorch/xla/issues/3596 */ // xla::Shape SlogdetOutputShape(const torch::lazy::Value& input); +xla::Shape SqrtOutputShape(const torch::lazy::Value& input); + xla::Shape TakeOutputShape(const torch::lazy::Value& input, const torch::lazy::Value& index); diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index e61bc9541eae..1790d477ccee 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -20,8 +20,7 @@ xla::XlaOp LowerProd(xla::XlaOp input, const std::vector& dimensions, xla::XlaOp casted_input; if (dtype) { casted_input = ConvertTo(input, XlaHelpers::TypeOfXlaOp(input), - MakeXlaPrimitiveType(*dtype, /*device=*/nullptr), - /*device=*/nullptr); + MakeXlaPrimitiveType(*dtype, /*device=*/nullptr)); } else { casted_input = ConvertToNumeric(input, XlaHelpers::TypeOfXlaOp(input)); } diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index ab4ca95f703c..13c9dfdbf58a 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -2579,10 +2579,6 @@ std::vector split_with_sizes(const XLATensorPtr& input, return input->MakeOutputTensors(node); } -XLATensorPtr sqrt(const XLATensorPtr& input) { - return input->CreateFrom(Sqrt(input->GetIrValue())); -} - XLATensorPtr squeeze(const XLATensorPtr& input) { auto input_shape = input->shape(); auto output_dimensions = BuildSqueezedDimensions( diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index acddf38482df..ffa05f02fe8f 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -818,8 +818,6 @@ std::vector split_with_sizes(const XLATensorPtr& input, std::vector split_size, int64_t dim); -XLATensorPtr sqrt(const XLATensorPtr& input); - // Squeeze out all trivial (size 1) dimensions. XLATensorPtr squeeze(const XLATensorPtr& input); diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index 34722987954f..29230546844b 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -170,7 +170,7 @@ xla::XlaOp CreateIndexAlongDim( xla::XlaOp updates = value; if (buffer_shape.element_type() != value_shape.element_type()) { updates = ConvertTo(updates, value_shape.element_type(), - buffer_shape.element_type(), /*device=*/nullptr); + buffer_shape.element_type()); } if (broadcast_value_to_index) { const xla::Shape& index_shape = ShapeHelper::ShapeOfXlaOp(index); @@ -603,7 +603,7 @@ xla::XlaOp CreateIndexUpdate( xla::XlaOp new_values = values; if (buffer_shape.element_type() != values_shape.element_type()) { new_values = ConvertTo(new_values, values_shape.element_type(), - buffer_shape.element_type(), /*device=*/nullptr); + buffer_shape.element_type()); } new_values = BuildExpand(new_values, expected_values_dims); const xla::Shape& new_values_shape = ShapeHelper::ShapeOfXlaOp(new_values); @@ -654,8 +654,7 @@ XlaOpCombiner NumericAddCombiner() { xla::XlaOp numeric_y = ConvertToNumeric(y); xla::XlaOp numeric_sum = numeric_x + numeric_y; return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } @@ -665,8 +664,7 @@ XlaOpCombiner NumericMulCombiner() { xla::XlaOp numeric_y = ConvertToNumeric(y); xla::XlaOp numeric_sum = numeric_x * numeric_y; return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } @@ -677,8 +675,7 @@ XlaOpCombiner NumericMinCombiner() { xla::XlaOp numeric_sum = xla::Min(numeric_x, numeric_y); // xla::XlaOp numeric_sum = xla::Min(numeric_x, numeric_y); return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } @@ -688,8 +685,7 @@ XlaOpCombiner NumericMaxCombiner() { xla::XlaOp numeric_y = ConvertToNumeric(y); xla::XlaOp numeric_sum = xla::Max(numeric_x, numeric_y); return ConvertTo(numeric_sum, XlaHelpers::TypeOfXlaOp(numeric_sum), - XlaHelpers::TypeOfXlaOp(x), - /*device=*/nullptr); + XlaHelpers::TypeOfXlaOp(x)); }; } diff --git a/torch_xla/distributed/xla_backend.py b/torch_xla/distributed/xla_backend.py index 1935a12d42a2..6b89911c9e77 100644 --- a/torch_xla/distributed/xla_backend.py +++ b/torch_xla/distributed/xla_backend.py @@ -15,7 +15,7 @@ def _create_xla_process_group(prefix_store, rank, size, timeout): def _register_xla_backend(): - dist.Backend.register_backend('xla', _create_xla_process_group) + dist.Backend.register_backend('xla', _create_xla_process_group, devices='xla') _register_xla_backend()