diff --git a/models/utility_functions.py b/models/utility_functions.py index 2b652f81542..f13fd48d8ca 100644 --- a/models/utility_functions.py +++ b/models/utility_functions.py @@ -15,6 +15,8 @@ from ttnn.device import Arch +from typing_extensions import deprecated + ### Math operations ### def _nearest_32(x): @@ -430,108 +432,22 @@ def convert_act_2d_matrix(activation, kernel_y, kernel_x, stride_y, stride_x, pa ### Tilizing / Untilizing ### +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize(x): - """ - This function tilizes a tensor. The last two tensor dims must be divisible by 32, after which this function - produces row major tiles and creates faces. The output of this function is a flattened list that - we can send to the device. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. - """ - nearest_32 = _nearest_32 - - assert isinstance( - x, (torch.Tensor, np.ndarray) - ), "Input to this function must be an instance of torch.Tensor or np.array" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(np.prod(x.shape)) - else: - ret = np.zeros(np.prod(x.shape)) - - idx = 0 - for B in range(x.shape[0]): - for C in range(x.shape[1]): - for H in range(0, x.shape[2], 32): - for W in range(0, x.shape[3], 32): - unfaced_tile = x[B, C, H : H + 32, W : W + 32] - - face0 = unfaced_tile[:16, :16] - face1 = unfaced_tile[:16, 16:] - face2 = unfaced_tile[16:, :16] - face3 = unfaced_tile[16:, 16:] - - for face in (face0, face1, face2, face3): - ret[idx : idx + 256] = face.reshape(-1) - idx += 256 - - return ret.reshape(x.shape) + return x +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize_to_list(x): """ - Tilize a PyTorch and then return the values as a flat list. The last two - tensor dims must be divisible by 32, after which this function produces row - major tiles and creates faces. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. + Returns a flattened list of the tensor """ - return tilize(x).reshape(-1).tolist() +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def untilize(x): - """ - This function untilizes a tensor to row major format. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. - """ - nearest_32 = _nearest_32 - - assert isinstance(x, (torch.Tensor, np.ndarray)), "Input to this function must be an instance of torch.Tensor" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(x.shape, dtype=x.dtype) - else: - ret = np.zeros(x.shape, dtype=x.dtype) - - for B in range(x.shape[0]): - for C in range(x.shape[1]): - x_hw = x[B, C, :].reshape(-1) - hw = 0 - for h in range(0, x.shape[2], 32): - for w in range(0, x.shape[3], 32): - f_tile = x_hw[hw : hw + 256].reshape(16, 16) - ret[B, C, h : h + 16, w : w + 16] = f_tile - - f_tile = x_hw[hw + 256 : hw + 512].reshape(16, 16) - ret[B, C, h : h + 16, w + 16 : w + 32] = f_tile - - f_tile = x_hw[hw + 512 : hw + 768].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w : w + 16] = f_tile - - f_tile = x_hw[hw + 768 : hw + 1024].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w + 16 : w + 32] = f_tile - hw += 1024 # traverse tiles in RM-order - - return ret + return x ### Measuring accuracy and other metrics ### diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py index 4245a35c3c2..3044f6bbb89 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_indexed_fill.py @@ -9,15 +9,7 @@ import ttnn import torch import numpy as np - - -tt_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype @pytest.mark.parametrize( diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py index e672856c3e2..b280d8e0b66 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_non_zero.py @@ -10,15 +10,7 @@ import torch import numpy as np import ttnn - - -tt_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype @pytest.mark.parametrize( diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py index 050099d62d6..1c19b8137e6 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_sharded_tensor.py @@ -11,14 +11,9 @@ import ttnn from models.utility_functions import get_debug_tensor +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype from enum import Enum -tt_dtype_to_torch_dtype = { - ttnn.uint32: torch.int32, - ttnn.uint16: torch.int16, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} TILE_WIDTH = 32 TILE_HEIGHT = 32 diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py b/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py index 63442308831..2fff322de44 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_conversion.py @@ -11,29 +11,10 @@ import numpy as np import ttnn - -tt_dtype_to_torch_dtype = { - ttnn.uint8: torch.uint8, - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.int32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, - ttnn.bfloat4_b: torch.float, -} - -tt_dtype_to_np_dtype = { - ttnn.uint8: np.ubyte, - ttnn.uint16: np.int16, - ttnn.uint32: np.int32, - ttnn.int32: np.int32, - ttnn.float32: np.float32, - ttnn.bfloat8_b: np.float32, - ttnn.bfloat4_b: np.float32, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype, tt_dtype_to_np_dtype +@pytest.mark.parametrize("convert_to_device", [True, False]) @pytest.mark.parametrize( "tt_dtype", [ @@ -49,7 +30,7 @@ ) @pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) @pytest.mark.parametrize("python_lib", [torch, np]) -def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): +def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, convert_to_device, device): torch.manual_seed(0) if python_lib == torch: @@ -64,7 +45,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): elif python_lib == np: if tt_dtype == ttnn.bfloat16: - pytest.skip("ttnn.bloat16 dtype is not supported yet for numpy tensors!") + pytest.skip("ttnn.bfloat16 dtype is not supported yet for numpy tensors!") dtype = tt_dtype_to_np_dtype[tt_dtype] if dtype in {np.ubyte, np.int16, np.int32}: @@ -82,8 +63,9 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): assert tt_tensor.storage_type() == ttnn.StorageType.BORROWED assert tt_tensor.layout == ttnn.ROW_MAJOR_LAYOUT - tt_tensor = tt_tensor.to(device) - tt_tensor = tt_tensor.cpu() + if convert_to_device: + tt_tensor = tt_tensor.to(device) + tt_tensor = tt_tensor.cpu() if python_lib == torch: py_tensor_after_round_trip = tt_tensor.to_torch() @@ -123,6 +105,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): } +@pytest.mark.parametrize("convert_to_device", [True, False]) @pytest.mark.parametrize( "python_dtype_str", [ @@ -137,7 +120,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device): ) @pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) @pytest.mark.parametrize("python_lib", [torch, np]) -def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str, device): +def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str, convert_to_device, device): torch.manual_seed(0) if python_lib == torch: @@ -165,8 +148,9 @@ def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str tt_tensor = ttnn.Tensor(py_tensor) assert tt_tensor.storage_type() == ttnn.StorageType.BORROWED - tt_tensor = tt_tensor.to(device) - tt_tensor = tt_tensor.cpu() + if convert_to_device: + tt_tensor = tt_tensor.to(device) + tt_tensor = tt_tensor.cpu() if python_lib == torch: py_tensor_after_round_trip = tt_tensor.to_torch() diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_creation.py b/tests/ttnn/unit_tests/tensor/test_tensor_creation.py new file mode 100644 index 00000000000..f0615abba97 --- /dev/null +++ b/tests/ttnn/unit_tests/tensor/test_tensor_creation.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import os +import pathlib + +import torch +import numpy as np + +import ttnn +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype + + +@pytest.mark.parametrize( + "layout", + [ + ttnn.ROW_MAJOR_LAYOUT, + ttnn.TILE_LAYOUT, + ], +) +@pytest.mark.parametrize( + "tt_dtype", + [ + ttnn.uint8, + ttnn.uint16, + ttnn.uint32, + ttnn.int32, + ttnn.float32, + ttnn.bfloat16, + ttnn.bfloat8_b, + ttnn.bfloat4_b, + ], +) +@pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) +def test_tensor_creation(shape, tt_dtype, layout, device): + torch.manual_seed(0) + + dtype = tt_dtype_to_torch_dtype[tt_dtype] + + if dtype in {torch.uint8, torch.int16, torch.int32}: + py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype) + else: + py_tensor = torch.rand(shape, dtype=dtype) + + tt_tensor = ttnn.Tensor(py_tensor, tt_dtype, device, layout) + + tt_tensor = tt_tensor.cpu() + + py_tensor_after_round_trip = tt_tensor.to_torch() + + assert py_tensor.dtype == py_tensor_after_round_trip.dtype + assert py_tensor.shape == py_tensor_after_round_trip.shape + + allclose_kwargs = {} + if tt_dtype == ttnn.bfloat8_b: + allclose_kwargs = dict(atol=1e-2) + elif tt_dtype == ttnn.bfloat4_b: + allclose_kwargs = dict(atol=0.2) + + passing = torch.allclose(py_tensor, py_tensor_after_round_trip, **allclose_kwargs) + assert passing + + +@pytest.mark.parametrize( + "layout", + [ + ttnn.ROW_MAJOR_LAYOUT, + ttnn.TILE_LAYOUT, + ], +) +@pytest.mark.parametrize( + "tt_dtype", + [ + ttnn.uint8, + ttnn.uint16, + ttnn.uint32, + ttnn.int32, + ttnn.float32, + ttnn.bfloat16, + ttnn.bfloat8_b, + ttnn.bfloat4_b, + ], +) +@pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) +def test_tensor_creation_api_parity(shape, tt_dtype, layout, device): + torch.manual_seed(0) + + if tt_dtype in (ttnn.bfloat8_b, ttnn.bfloat4_b) and layout == ttnn.ROW_MAJOR_LAYOUT: + pytest.skip("{} is only valid for ttnn.TILE_LAYOUT!".format(tt_dtype)) + + dtype = tt_dtype_to_torch_dtype[tt_dtype] + + if dtype in {torch.uint8, torch.int16, torch.int32}: + py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype) + else: + py_tensor = torch.rand(shape, dtype=dtype) + + tt_tensor_1 = ttnn.Tensor(py_tensor, tt_dtype, device, layout) + tt_tensor_2 = ttnn.from_torch(py_tensor, tt_dtype, device=device, layout=layout) + + tt_tensor_1 = tt_tensor_1.cpu() + tt_tensor_2 = tt_tensor_2.cpu() + + py_tensor_after_round_trip_1 = tt_tensor_1.to_torch() + py_tensor_after_round_trip_2 = tt_tensor_2.to_torch() + py_tensor_after_round_trip_3 = ttnn.to_torch(tt_tensor_1) + py_tensor_after_round_trip_4 = ttnn.to_torch(tt_tensor_2) + + allclose_kwargs = {} + if tt_dtype == ttnn.bfloat8_b: + allclose_kwargs = dict(atol=1e-2) + elif tt_dtype == ttnn.bfloat4_b: + allclose_kwargs = dict(atol=0.2) + + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_1, **allclose_kwargs) + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_2, **allclose_kwargs) + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_3, **allclose_kwargs) + passing = torch.allclose(py_tensor, py_tensor_after_round_trip_4, **allclose_kwargs) + assert passing diff --git a/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py b/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py index 1db497c0843..a56dde83d19 100644 --- a/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py +++ b/tests/ttnn/unit_tests/tensor/test_tensor_serialization.py @@ -11,15 +11,7 @@ import numpy as np import ttnn - -tt_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, - ttnn.bfloat4_b: torch.float, -} +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype @pytest.mark.parametrize("shape", [(2, 3, 64, 96)]) diff --git a/tests/ttnn/unit_tests/test_print_tensor.py b/tests/ttnn/unit_tests/test_print_tensor.py index 66254f7d363..90f1ecd5157 100644 --- a/tests/ttnn/unit_tests/test_print_tensor.py +++ b/tests/ttnn/unit_tests/test_print_tensor.py @@ -7,14 +7,8 @@ import torch import ttnn +from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype -ttnn_dtype_to_torch_dtype = { - ttnn.uint16: torch.int16, - ttnn.uint32: torch.int32, - ttnn.float32: torch.float, - ttnn.bfloat16: torch.bfloat16, - ttnn.bfloat8_b: torch.float, -} GOLDEN_TENSOR_STRINGS = { ( @@ -77,7 +71,7 @@ def test_print(device, dtype, layout, profile, deallocate): ttnn.set_printoptions(profile=profile) - torch_dtype = ttnn_dtype_to_torch_dtype[dtype] + torch_dtype = tt_dtype_to_torch_dtype[dtype] shape = (2, 16, 64, 32) if torch_dtype in {torch.int16, torch.int32}: diff --git a/tests/ttnn/utils_for_testing.py b/tests/ttnn/utils_for_testing.py index fb083a681ff..92849b32e57 100644 --- a/tests/ttnn/utils_for_testing.py +++ b/tests/ttnn/utils_for_testing.py @@ -10,6 +10,33 @@ from models.utility_functions import comp_pcc, comp_equal, divup, roundup from typing import Tuple +import ttnn +import torch +import numpy as np + + +# Dictionaries for converting dtypes +tt_dtype_to_torch_dtype = { + ttnn.uint8: torch.uint8, + ttnn.uint16: torch.int16, + ttnn.uint32: torch.int32, + ttnn.int32: torch.int32, + ttnn.float32: torch.float, + ttnn.bfloat16: torch.bfloat16, + ttnn.bfloat8_b: torch.float, + ttnn.bfloat4_b: torch.float, +} + +tt_dtype_to_np_dtype = { + ttnn.uint8: np.ubyte, + ttnn.uint16: np.int16, + ttnn.uint32: np.int32, + ttnn.int32: np.int32, + ttnn.float32: np.float32, + ttnn.bfloat8_b: np.float32, + ttnn.bfloat4_b: np.float32, +} + def construct_pcc_assert_message(message, expected_pytorch_result, actual_pytorch_result): messages = [] diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 48a360fb3cb..17de2f3493e 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -66,17 +66,17 @@ void log_external_operation( #endif template -Tensor create_owned_tensor( - T* data_ptr, - size_t num_elements, - tt::stl::Span shape, - DataType data_type, - Layout layout, - const std::optional& optional_tile = std::nullopt) { - auto data = std::vector(data_ptr, data_ptr + num_elements); +Tensor create_owned_tensor(T* data_ptr, const ttnn::TensorSpec& tensor_spec) { + std::size_t num_elements = tensor_spec.logical_shape().volume(); + auto data = std::vector(data_ptr, data_ptr + num_elements); auto buffer = owned_buffer::create(std::move(data)); + + if (tensor_spec.layout() == Layout::TILE) { + data = tensor_impl::convert_layout_row_major_to_tile(tensor_spec.physical_shape(), tensor_spec.tile(), buffer); + buffer = owned_buffer::create(std::move(data)); + } auto storage = OwnedStorage{std::move(buffer)}; - return Tensor(std::move(storage), shape, data_type, layout, optional_tile); + return Tensor(std::move(storage), tensor_spec); } OwnedBuffer create_owned_buffer_from_vector_of_floats(std::vector&& data, DataType data_type) { @@ -138,7 +138,7 @@ Tensor convert_float_vector_to_tt_tensor( return tensor; } auto owned_buffer = create_owned_buffer_from_vector_of_floats(std::move(data), data_type); - auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout, tile); + auto tensor = Tensor(OwnedStorage{owned_buffer}, shape, data_type, Layout::ROW_MAJOR, tile).to(layout); if (device) { return tensor.to(device, memory_config.value_or(MemoryConfig{})); } @@ -146,23 +146,30 @@ Tensor convert_float_vector_to_tt_tensor( } Tensor create_tt_tensor_from_py_data( - std::size_t num_elements, std::size_t py_data_ptr, - const ttnn::SmallVector& shape, - const DataType data_type, - const std::optional& optional_tile, - bool enable_borrow, - const std::function& on_creation_callback = [] {}, - const std::function& on_destruction_callback = [] {}) { + const TensorSpec& tensor_spec, + Device* device, + bool override_enable_borrow, + const std::function& on_creation_callback, + const std::function& on_destruction_callback) { + auto layout = tensor_spec.layout(); + + bool enable_borrow = true; + if (layout != Layout::ROW_MAJOR or override_enable_borrow) { + enable_borrow = false; + } + + auto data_type = tensor_spec.data_type(); + std::size_t num_elements = tensor_spec.logical_shape().volume(); switch (data_type) { case DataType::UINT8: { auto data_ptr = reinterpret_cast(py_data_ptr); if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::UINT16: { @@ -170,9 +177,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::INT32: { @@ -180,9 +187,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::UINT32: { @@ -190,9 +197,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::FLOAT32: { @@ -200,9 +207,9 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } // TODO: This is not supported for numpy @@ -211,27 +218,28 @@ Tensor create_tt_tensor_from_py_data( if (enable_borrow) { auto storage = BorrowedStorage( borrowed_buffer::Buffer(data_ptr, num_elements), on_creation_callback, on_destruction_callback); - return Tensor(std::move(storage), shape, data_type, Layout::ROW_MAJOR, optional_tile); + return Tensor(std::move(storage), tensor_spec); } else { - return create_owned_tensor(data_ptr, num_elements, shape, data_type, Layout::ROW_MAJOR, optional_tile); + return create_owned_tensor(data_ptr, tensor_spec); } } case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: { auto data_ptr = reinterpret_cast(py_data_ptr); - auto data = std::vector(data_ptr, data_ptr + num_elements); - auto buffer = owned_buffer::create(std::move(data)); - auto tile = optional_tile.value_or(Tile()); - auto tensor = Tensor(OwnedStorage{buffer}, shape, DataType::FLOAT32, Layout::ROW_MAJOR, optional_tile) - .to(Layout::TILE); - auto output_float_data = owned_buffer::get_as(tensor).get(); + auto float_tensor_spec = TensorSpec( + tensor_spec.logical_shape(), + TensorLayout(DataType::FLOAT32, tensor_spec.page_config(), tensor_spec.memory_config())); + auto float_tensor = create_owned_tensor(data_ptr, float_tensor_spec); + + auto tile = tensor_spec.tensor_layout().get_page_config().get_tile(); + auto output_float_data = owned_buffer::get_as(float_tensor).get(); auto output_packed_data = data_type == DataType::BFLOAT8_B ? pack_fp32_vec_as_bfp8_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile) : pack_fp32_vec_as_bfp4_tiles( output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false, tile); auto output_buffer = owned_buffer::create(std::move(output_packed_data)); - return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), shape, data_type, Layout::TILE, tile); + return Tensor(std::move(OwnedStorage{std::move(output_buffer)}), tensor_spec); } default: { TT_THROW("Unsupported DataType: {}", data_type); @@ -242,16 +250,26 @@ Tensor create_tt_tensor_from_py_data( Tensor convert_python_tensor_to_tt_tensor( const py::handle& py_tensor, - std::optional optional_data_type = std::nullopt, - const std::optional& optional_tile = std::nullopt, - bool enable_borrow = true) { + std::optional optional_data_type, + std::optional optional_layout, + const std::optional& optional_tile, + const MemoryConfig& memory_config, + Device* device, + bool override_enable_borrow = false) { GraphTracker::instance().track_function_start( - "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor", py_tensor, optional_data_type, enable_borrow); + "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor", + py_tensor, + optional_data_type, + optional_layout, + optional_tile, + memory_config, + device, + override_enable_borrow); py::object torch = py::module_::import("torch"); py::object np = py::module_::import("numpy"); auto py_dtype = py_tensor.attr("dtype"); - auto shape = py::cast>(py_tensor.attr("shape")); + auto shape = ttnn::SimpleShape(py::cast>(py_tensor.attr("shape"))); DataType data_type; @@ -323,7 +341,7 @@ Tensor convert_python_tensor_to_tt_tensor( num_elements = py::cast(contiguous_py_tensor.attr("numel")()); py_data_ptr = py::cast(contiguous_py_tensor.attr("data_ptr")()); } else if (py::isinstance(py_tensor, np.attr("ndarray"))) { - TT_FATAL(enable_borrow, "Owned storage for numpy tensors is untested!"); + TT_FATAL(!override_enable_borrow, "Disabling borrowed buffers for numpy tensors is untested!"); contiguous_py_tensor = np.attr("ascontiguousarray")(py_tensor); @@ -386,17 +404,35 @@ Tensor convert_python_tensor_to_tt_tensor( TT_THROW("The argument must be of type torch.Tensor or numpy.ndarray!"); } + // TODO: Remove check of num_elements from python against volume of ttnn::SimpleShape + TT_FATAL( + num_elements == shape.volume(), + "Number of elements from python tensor {} must match volume of shape {}!", + num_elements, + shape.volume()); + + Layout layout = optional_layout.value_or(Layout::ROW_MAJOR); + if (data_type == DataType::BFLOAT8_B or data_type == DataType::BFLOAT4_B) { + if (optional_layout.has_value() and optional_layout.value() != Layout::TILE) { + log_warning( + tt::LogAlways, + "Tensor layout must be Layout::TILE for bfloat8_b or bfloat4_b! Tensor layout will be {} instead of " + "the requested {}!", + Layout::TILE, + optional_layout.value()); + } + layout = Layout::TILE; + } + + auto tensor_spec = TensorSpec(shape, TensorLayout(data_type, PageConfig(layout, optional_tile), memory_config)); auto on_creation_callback = [tensor = contiguous_py_tensor] { tensor.inc_ref(); }; auto on_destruction_callback = [tensor = contiguous_py_tensor] { tensor.dec_ref(); }; auto output = create_tt_tensor_from_py_data( - num_elements, - py_data_ptr, - shape, - data_type, - optional_tile, - enable_borrow, - on_creation_callback, - on_destruction_callback); + py_data_ptr, tensor_spec, device, override_enable_borrow, on_creation_callback, on_destruction_callback); + + if (device) { + output = output.to(device, memory_config); + } output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; @@ -411,7 +447,8 @@ Tensor convert_python_tensors_to_tt_tensors( "tt::tt_metal::detail::convert_python_tensors_to_tt_tensors", tensor_shards, data_type, strategy); std::vector tt_shards; for (const auto& shard : tensor_shards) { - tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor(shard, data_type, tile, false)); + tt_shards.push_back(detail::convert_python_tensor_to_tt_tensor( + shard, data_type, Layout::ROW_MAJOR, tile, MemoryConfig{}, nullptr, true)); } std::vector host_owned_buffers; std::vector host_owned_shapes; @@ -432,15 +469,68 @@ Tensor convert_python_tensors_to_tt_tensors( return output; } -std::pair, DataType> get_buffer_and_dtype_from_tensor( - const Tensor& tt_tensor) { +template +owned_buffer::Buffer create_row_major_owned_buffer( + owned_buffer::Buffer owned_buffer, const ttnn::TensorSpec& tensor_spec) { + if (tensor_spec.layout() == Layout::TILE) { + auto data = tensor_impl::convert_layout_tile_to_row_major( + tensor_spec.physical_shape(), tensor_spec.tile(), owned_buffer); + return owned_buffer::create(std::move(data)); + } + return owned_buffer; +} + +std::variant get_host_buffer_from_tensor(const Tensor& tt_tensor) { TT_ASSERT(tt_tensor.storage_type() == StorageType::OWNED or tt_tensor.storage_type() == StorageType::BORROWED); - auto buffer = std::visit( - [](auto&& storage) -> std::variant { + const auto& tensor_spec = tt_tensor.get_tensor_spec(); + return std::visit( + [&tensor_spec, &tt_tensor](auto&& storage) -> std::variant { using T = std::decay_t; if constexpr (std::is_same_v) { - return storage.buffer; + auto tt_dtype = tensor_spec.data_type(); + switch (tt_dtype) { + case DataType::UINT8: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::UINT16: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::INT32: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::UINT32: { + return create_row_major_owned_buffer( + owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::FLOAT32: { + return create_row_major_owned_buffer(owned_buffer::get_as(storage.buffer), tensor_spec); + } + case DataType::BFLOAT16: { + return create_row_major_owned_buffer( + owned_buffer::get_as<::bfloat16>(storage.buffer), tensor_spec); + } + case DataType::BFLOAT8_B: + case DataType::BFLOAT4_B: { + const auto& tile = tensor_spec.tile(); + auto uint32_data = owned_buffer::get_as(storage.buffer).get(); + auto float_unpacked_data = + tt_dtype == DataType::BFLOAT8_B + ? unpack_bfp8_tiles_into_float_vec( + uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile) + : unpack_bfp4_tiles_into_float_vec( + uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); + auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); + return create_row_major_owned_buffer(input_float_buffer, tensor_spec); + } + default: { + TT_THROW("Unsupported DataType: {}", tt_dtype); + break; + } + } } else if constexpr (std::is_same_v) { TT_THROW("Device tensor cannot be converted to torch"); } else if constexpr (std::is_same_v) { @@ -456,52 +546,64 @@ std::pair, DataType> get_buffer_and_dt } }, tt_tensor.get_storage()); - - const auto tile = tt_tensor.get_tensor_spec().tile(); - auto tt_dtype = tt_tensor.get_dtype(); - if (tt_dtype == DataType::BFLOAT8_B || tt_dtype == DataType::BFLOAT4_B) { - TT_ASSERT( - std::holds_alternative(buffer), - "Unexpected type {}", - tt::stl::get_active_type_name_in_variant(buffer)); - auto uint32_data = std::get>(std::get(buffer)).get(); - auto float_unpacked_data = - tt_dtype == DataType::BFLOAT8_B - ? unpack_bfp8_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile) - : unpack_bfp4_tiles_into_float_vec(uint32_data, /*row_major_output=*/false, /*is_exp_a=*/false, tile); - auto input_float_buffer = owned_buffer::create(std::move(float_unpacked_data)); - auto float_tensor = Tensor( - OwnedStorage{input_float_buffer}, - tt_tensor.get_shape(), - DataType::FLOAT32, - tt_tensor.get_layout(), - tile) - .to(Layout::ROW_MAJOR); - auto output_float_data = owned_buffer::get_as(float_tensor).get(); - buffer = owned_buffer::create(std::move(output_float_data)); - tt_dtype = DataType::FLOAT32; - } - - return {buffer, tt_dtype}; } py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) { GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_torch_tensor", tt_tensor); - auto [buffer, buffer_dtype] = get_buffer_and_dtype_from_tensor(tt_tensor); + auto buffer = get_host_buffer_from_tensor(tt_tensor); py::object torch = py::module_::import("torch"); auto frombuffer = torch.attr("frombuffer"); - const auto tt_dtype_to_torch_dtype = std::map{ - {DataType::UINT8, torch.attr("uint8")}, - {DataType::UINT16, torch.attr("int16")}, // TODO(arakhmati): add DataType::INT16 - {DataType::INT32, torch.attr("int32")}, - {DataType::UINT32, torch.attr("int32")}, // TODO(arakhmati): add DataType::INT32 - {DataType::FLOAT32, torch.attr("float32")}, - {DataType::BFLOAT16, torch.attr("bfloat16")}, - }; - auto torch_dtype = tt_dtype_to_torch_dtype.at(buffer_dtype); + auto torch_dtype = [&]() { + if (std::holds_alternative(buffer)) { + return std::visit( + [&torch](auto& owned_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return torch.attr("uint8"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int16"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("float32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("bfloat16"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + + } else if (std::holds_alternative(buffer)) { + return std::visit( + [&torch](auto& borrowed_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return torch.attr("uint8"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int16"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("int32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("float32"); + } else if constexpr (std::is_same_v>) { + return torch.attr("bfloat16"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + } else { + TT_THROW("Only OwnedBuffer or BorrowedBuffer is supported for converting to python buffers!"); + } + }(); auto shape = tt_tensor.get_legacy_shape(); auto torch_shape = std::vector(std::begin(shape), std::end(shape)); @@ -527,19 +629,59 @@ py::object convert_tt_tensor_to_torch_tensor(const Tensor& tt_tensor) { py::object convert_tt_tensor_to_numpy_tensor(const Tensor& tt_tensor) { GraphTracker::instance().track_function_start("tt::tt_metal::detail::convert_tt_tensor_to_numpy_tensor", tt_tensor); - auto [buffer, buffer_dtype] = get_buffer_and_dtype_from_tensor(tt_tensor); + auto buffer = get_host_buffer_from_tensor(tt_tensor); py::object np = py::module_::import("numpy"); auto frombuffer = np.attr("frombuffer"); - const auto tt_dtype_to_np_dtype = std::map{ - {DataType::UINT8, np.attr("ubyte")}, - {DataType::UINT16, np.attr("int16")}, // TODO(arakhmati): add DataType::INT16 - {DataType::INT32, np.attr("int32")}, - {DataType::UINT32, np.attr("int32")}, // TODO(arakhmati): add DataType::INT32 - {DataType::FLOAT32, np.attr("float32")}, - }; - auto np_dtype = tt_dtype_to_np_dtype.at(buffer_dtype); + auto np_dtype = [&]() { + if (std::holds_alternative(buffer)) { + return std::visit( + [&np](auto& owned_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return np.attr("ubyte"); + } else if constexpr (std::is_same_v>) { + return np.attr("int16"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("float32"); + } else if constexpr (std::is_same_v>) { + TT_THROW("Bfloat16 is not supported for numpy!"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + + } else if (std::holds_alternative(buffer)) { + return std::visit( + [&np](auto& borrowed_buffer) -> py::object { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return np.attr("ubyte"); + } else if constexpr (std::is_same_v>) { + return np.attr("int16"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("int32"); + } else if constexpr (std::is_same_v>) { + return np.attr("float32"); + } else if constexpr (std::is_same_v>) { + TT_THROW("Bfloat16 is not supported for numpy!"); + } else { + static_assert(tt::stl::concepts::always_false_v, "Unsupported buffer!"); + } + }, + std::get(buffer)); + } else { + TT_THROW("Only OwnedBuffer or BorrowedBuffer is supported for converting to python buffers!"); + } + }(); auto shape = tt_tensor.get_legacy_shape(); auto np_shape = std::vector(std::begin(shape), std::end(shape)); @@ -842,7 +984,8 @@ void pytensor_module(py::module& m_tensor) { if (py::isinstance(tensor)) { return detail::convert_python_tensors_to_tt_tensors(tensor, data_type, tile, strategy); } - return detail::convert_python_tensor_to_tt_tensor(tensor, data_type, tile); + return detail::convert_python_tensor_to_tt_tensor( + tensor, data_type, std::nullopt, tile, MemoryConfig{}, nullptr); }), py::arg("tensor"), py::arg("data_type") = std::nullopt, @@ -857,6 +1000,8 @@ void pytensor_module(py::module& m_tensor) { +--------------+------------------------+ | data_type | TT Tensor data type | +--------------+------------------------+ + | tile | TT Tile Spec | + +--------------+------------------------+ Example of creating a TT Tensor that uses torch.Tensor's storage as its own storage: @@ -872,16 +1017,15 @@ void pytensor_module(py::module& m_tensor) { Layout layout, const MemoryConfig& mem_config, const std::optional& tile) { - auto tensor = detail::convert_python_tensor_to_tt_tensor(python_tensor, data_type, tile); - auto layout_tensor = tensor.to(layout); - return layout_tensor.to(device, mem_config); + return detail::convert_python_tensor_to_tt_tensor( + python_tensor, data_type, layout, tile, mem_config, device); }), py::arg("tensor"), py::arg("data_type") = std::nullopt, - py::arg("device").noconvert(), - py::arg("layout").noconvert(), - py::arg("mem_config").noconvert(), - py::arg("tile") = std::nullopt, + py::arg("device") = nullptr, + py::arg("layout").noconvert() = Layout::ROW_MAJOR, + py::arg("mem_config").noconvert() = MemoryConfig{}, + py::arg("tile").noconvert() = std::nullopt, py::return_value_policy::move, R"doc( +--------------+------------------------+ @@ -897,14 +1041,17 @@ void pytensor_module(py::module& m_tensor) { +--------------+------------------------+ | mem_config | TT memory_config | +--------------+------------------------+ + | tile | TT Tile Spec | + +--------------+------------------------+ - Example of creating a TT Tensor that uses torch.Tensor's storage as its own storage: + Example of creating a TT Tensor from numpy tensor: .. code-block:: python + device = ttnn.open_device(device_id=0) py_tensor = np.zeros((1, 1, 32, 32)) - ttnn.Tensor(py_tensor) + ttnn.Tensor(py_tensor, ttnn.bfloat16, device, ttnn.TILE_LAYOUT) )doc") .def_property_readonly("shape", [](const Tensor& self) { return self.get_shape(); }) .def_property_readonly("dtype", [](const Tensor& self) { return self.get_dtype(); }) diff --git a/ttnn/cpp/ttnn/tensor/tensor_spec.hpp b/ttnn/cpp/ttnn/tensor/tensor_spec.hpp index 125b3bb719f..172e0d881f5 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_spec.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_spec.hpp @@ -28,6 +28,7 @@ class TensorSpec final { DataType data_type() const { return tensor_layout_.get_data_type(); } Layout layout() const { return tensor_layout_.get_layout(); } PageConfig page_config() const { return tensor_layout_.get_page_config(); } + const MemoryConfig& memory_config() const { return tensor_layout_.get_memory_config(); } const ttnn::SimpleShape& padded_shape() const { return cached_padded_shape_; } const Size& physical_shape() const { return cached_physical_shape_; } ttnn::Shape shape() const { return ttnn::Shape(logical_shape_.view(), cached_padded_shape_.view()); } diff --git a/ttnn/tt_lib/fused_ops/softmax.py b/ttnn/tt_lib/fused_ops/softmax.py index f5b2f5fceb4..904b4cea008 100644 --- a/ttnn/tt_lib/fused_ops/softmax.py +++ b/ttnn/tt_lib/fused_ops/softmax.py @@ -42,7 +42,7 @@ def ref_stable_softmax(x): if __name__ == "__main__": - device = ttnn.open_device(0) + device = ttnn.open_device(device_id=0) H, W = 64, 96 torch.manual_seed(123) diff --git a/ttnn/tt_lib/utils.py b/ttnn/tt_lib/utils.py index 9883666b81f..a61f9759464 100644 --- a/ttnn/tt_lib/utils.py +++ b/ttnn/tt_lib/utils.py @@ -8,6 +8,8 @@ import torch import numpy as np +from typing_extensions import deprecated + def _nearest_32(x): return math.ceil(x / 32) * 32 @@ -134,108 +136,22 @@ def convert_act_2d_matrix(activation, kernel_y, kernel_x, stride_y, stride_x, pa return ret.reshape(ret_shape) +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize(x): - """ - This function tilizes a tensor. The last two tensor dims must be divisible by 32, after which this function - produces row major tiles and creates faces. The output of this function is a flattened list that - we can send to the device. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. - """ - nearest_32 = _nearest_32 - - assert isinstance( - x, (torch.Tensor, np.ndarray) - ), "Input to this function must be an instance of torch.Tensor or np.array" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(np.prod(x.shape)) - else: - ret = np.zeros(np.prod(x.shape)) - - idx = 0 - for B in range(x.shape[0]): - for C in range(x.shape[1]): - for H in range(0, x.shape[2], 32): - for W in range(0, x.shape[3], 32): - unfaced_tile = x[B, C, H : H + 32, W : W + 32] - - face0 = unfaced_tile[:16, :16] - face1 = unfaced_tile[:16, 16:] - face2 = unfaced_tile[16:, :16] - face3 = unfaced_tile[16:, 16:] - - for face in (face0, face1, face2, face3): - ret[idx : idx + 256] = face.reshape(-1) - idx += 256 - - return ret.reshape(x.shape) + return x +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def tilize_to_list(x): """ - Tilize a PyTorch and then return the values as a flat list. The last two - tensor dims must be divisible by 32, after which this function produces row - major tiles and creates faces. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. + Returns a flattened list of the tensor """ - return tilize(x).reshape(-1).tolist() +@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:") def untilize(x): - """ - This function untilizes a tensor to row major format. - - :param x: Input PyTorch Tensor - :type x: class:`torch.Tensor` - - WARNING: This function should eventually be retired in favour of fully tilizing on device. - """ - nearest_32 = _nearest_32 - - assert isinstance(x, (torch.Tensor, np.ndarray)), "Input to this function must be an instance of torch.Tensor" - assert len(x.shape) == 4, "Only 4D tensors suppported" - assert (x.shape[-2] % 32) == 0 and ( - x.shape[-1] % 32 - ) == 0, "The last two dimensions of the tensor must be divisible by 32" - - if isinstance(x, torch.Tensor): - ret = torch.zeros(x.shape) - else: - ret = np.zeros(x.shape) - - for B in range(x.shape[0]): - for C in range(x.shape[1]): - x_hw = x[B, C, :].reshape(-1) - hw = 0 - for h in range(0, x.shape[2], 32): - for w in range(0, x.shape[3], 32): - f_tile = x_hw[hw : hw + 256].reshape(16, 16) - ret[B, C, h : h + 16, w : w + 16] = f_tile - - f_tile = x_hw[hw + 256 : hw + 512].reshape(16, 16) - ret[B, C, h : h + 16, w + 16 : w + 32] = f_tile - - f_tile = x_hw[hw + 512 : hw + 768].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w : w + 16] = f_tile - - f_tile = x_hw[hw + 768 : hw + 1024].reshape(16, 16) - ret[B, C, h + 16 : h + 32, w + 16 : w + 32] = f_tile - hw += 1024 # traverse tiles in RM-order - - return ret + return x def print_diff_argmax(a, b, annotation=""):