#13127: Use tensor spec in conversion between python and tt tensors
- Significant changes in ttnn/cpp/pybind11/pytensor.cpp:
  * Use tensor spec in create_owned_tensor
  * Add conversion between ROW_MAJOR and TILE layouts for ttnn.Tensor(...)/tensor.to(...) APIs (see the sketch after this list)
    ** For ttnn.Tensor(python_tensor, ...), handling is now internal and not through .to(layout)
    ** For ttnn.Tensor(float_vector, ...), use .to(layout) to convert to TILE if needed
    ** Make tilize, tilize_to_list, and untilize python utility functions no-ops and mark as deprecated (see the replacement sketch after the models/utility_functions.py diff below)
  * Add analogous create_row_major_owned_buffer from tensor buffer
    ** Commonize handling of BFLOAT8_B/BFLOAT4_B as float tensors/buffers
    ** Always use OwnedBuffer if conversion to/from TILE layout is required
  * Automatically deduce python dtype from owned buffers instead of mapping based on tt dtype
  * Set defaults for pybound init so it is more usable
  * Invert meaning of enable_borrow (now called override_enable_borrow)
    ** Make enable_borrow internal to create_tt_tensor_from_py_data
- Update tensor init documentation and sample code for tile arg and creating tensors on device
- Add memory_config() to TensorSpec
- Commonize tt_dtype_to_torch_dtype and tt_dtype_to_np_dtype dicts across ttnn unit tests
- Add test for host side tensor conversion in tests/ttnn/unit_tests/tensor/test_tensor_conversion.py
- Add new tests/ttnn/unit_tests/tensor/test_tensor_creation.py tests
  * Coverage for directly creating device tensors with ttnn.Tensor(...)
  * Coverage for API parity between ttnn.from_device/ttnn.to_device and ttnn.Tensor(...)/tensor.to(...)
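
For reference, the round-trip these changes enable, as a minimal sketch mirroring the pattern in the new tests (the ttnn.open_device/ttnn.close_device handling and the sample shape are illustrative assumptions, not part of this commit):

    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)

    py_tensor = torch.rand((2, 3, 64, 96), dtype=torch.bfloat16)

    # Tilizing is handled internally; no tilize()/tilize_to_list() pre-processing.
    tt_tensor = ttnn.Tensor(py_tensor, ttnn.bfloat16, device, ttnn.TILE_LAYOUT)

    # Pull back to host and compare against the original torch tensor.
    py_tensor_after_round_trip = tt_tensor.cpu().to_torch()
    assert torch.allclose(py_tensor, py_tensor_after_round_trip)

    ttnn.close_device(device)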
TT-BrianLiu committed Dec 12, 2024
1 parent bc00438 · commit 5b04331
Showing 13 changed files with 444 additions and 366 deletions.
100 changes: 8 additions & 92 deletions models/utility_functions.py
@@ -15,6 +15,8 @@
 
 from ttnn.device import Arch
 
+from typing_extensions import deprecated
+
 
 ### Math operations ###
 def _nearest_32(x):
@@ -430,108 +432,22 @@ def convert_act_2d_matrix(activation, kernel_y, kernel_x, stride_y, stride_x, pa
 
 
 ### Tilizing / Untilizing ###
+@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:")
 def tilize(x):
-    """
-    This function tilizes a tensor. The last two tensor dims must be divisible by 32, after which this function
-    produces row major tiles and creates faces. The output of this function is a flattened list that
-    we can send to the device.
-    :param x: Input PyTorch Tensor
-    :type x: class:`torch.Tensor`
-    WARNING: This function should eventually be retired in favour of fully tilizing on device.
-    """
-    nearest_32 = _nearest_32
-
-    assert isinstance(
-        x, (torch.Tensor, np.ndarray)
-    ), "Input to this function must be an instance of torch.Tensor or np.ndarray"
-    assert len(x.shape) == 4, "Only 4D tensors supported"
-    assert (x.shape[-2] % 32) == 0 and (
-        x.shape[-1] % 32
-    ) == 0, "The last two dimensions of the tensor must be divisible by 32"
-
-    if isinstance(x, torch.Tensor):
-        ret = torch.zeros(np.prod(x.shape))
-    else:
-        ret = np.zeros(np.prod(x.shape))
-
-    idx = 0
-    for B in range(x.shape[0]):
-        for C in range(x.shape[1]):
-            for H in range(0, x.shape[2], 32):
-                for W in range(0, x.shape[3], 32):
-                    unfaced_tile = x[B, C, H : H + 32, W : W + 32]
-
-                    face0 = unfaced_tile[:16, :16]
-                    face1 = unfaced_tile[:16, 16:]
-                    face2 = unfaced_tile[16:, :16]
-                    face3 = unfaced_tile[16:, 16:]
-
-                    for face in (face0, face1, face2, face3):
-                        ret[idx : idx + 256] = face.reshape(-1)
-                        idx += 256
-
-    return ret.reshape(x.shape)
+    return x
 
 
+@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:")
 def tilize_to_list(x):
-    """
-    Tilize a PyTorch tensor and then return the values as a flat list. The last two
-    tensor dims must be divisible by 32, after which this function produces row
-    major tiles and creates faces.
-    :param x: Input PyTorch Tensor
-    :type x: class:`torch.Tensor`
-    WARNING: This function should eventually be retired in favour of fully tilizing on device.
-    Returns a flattened list of the tensor
-    """
 
     return tilize(x).reshape(-1).tolist()
 
 
+@deprecated("PyTorch data is handled automatically in tensor infra. This function does nothing now:")
 def untilize(x):
-    """
-    This function untilizes a tensor to row major format.
-    :param x: Input PyTorch Tensor
-    :type x: class:`torch.Tensor`
-    WARNING: This function should eventually be retired in favour of fully tilizing on device.
-    """
-    nearest_32 = _nearest_32
-
-    assert isinstance(x, (torch.Tensor, np.ndarray)), "Input to this function must be an instance of torch.Tensor or np.ndarray"
-    assert len(x.shape) == 4, "Only 4D tensors supported"
-    assert (x.shape[-2] % 32) == 0 and (
-        x.shape[-1] % 32
-    ) == 0, "The last two dimensions of the tensor must be divisible by 32"
-
-    if isinstance(x, torch.Tensor):
-        ret = torch.zeros(x.shape, dtype=x.dtype)
-    else:
-        ret = np.zeros(x.shape, dtype=x.dtype)
-
-    for B in range(x.shape[0]):
-        for C in range(x.shape[1]):
-            x_hw = x[B, C, :].reshape(-1)
-            hw = 0
-            for h in range(0, x.shape[2], 32):
-                for w in range(0, x.shape[3], 32):
-                    f_tile = x_hw[hw : hw + 256].reshape(16, 16)
-                    ret[B, C, h : h + 16, w : w + 16] = f_tile
-
-                    f_tile = x_hw[hw + 256 : hw + 512].reshape(16, 16)
-                    ret[B, C, h : h + 16, w + 16 : w + 32] = f_tile
-
-                    f_tile = x_hw[hw + 512 : hw + 768].reshape(16, 16)
-                    ret[B, C, h + 16 : h + 32, w : w + 16] = f_tile
-
-                    f_tile = x_hw[hw + 768 : hw + 1024].reshape(16, 16)
-                    ret[B, C, h + 16 : h + 32, w + 16 : w + 32] = f_tile
-                    hw += 1024  # traverse tiles in RM-order
-
-    return ret
+    return x
 
 
 ### Measuring accuracy and other metrics ###
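With tilize, tilize_to_list, and untilize reduced to no-ops, callers hand PyTorch tensors to the tensor infra directly and request layouts on the tensor itself. A hedged sketch of the replacement pattern (the device handle, the tensor x, and the bfloat16 dtype are illustrative assumptions, not lines from this diff):

    # Before: x_list = tilize_to_list(x); now construct directly and let the
    # tensor infra convert layouts internally.
    tt_tensor = ttnn.Tensor(x, ttnn.bfloat16, device, ttnn.TILE_LAYOUT)

    # Instead of untilize(), request ROW_MAJOR on the host tensor.
    x_row_major = tt_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()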
@@ -9,15 +9,7 @@
 import ttnn
 import torch
 import numpy as np
-
-
-tt_dtype_to_torch_dtype = {
-    ttnn.uint16: torch.int16,
-    ttnn.uint32: torch.int32,
-    ttnn.float32: torch.float,
-    ttnn.bfloat16: torch.bfloat16,
-    ttnn.bfloat8_b: torch.float,
-}
+from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype
 
 
 @pytest.mark.parametrize(
@@ -10,15 +10,7 @@
 import torch
 import numpy as np
 import ttnn
-
-
-tt_dtype_to_torch_dtype = {
-    ttnn.uint16: torch.int16,
-    ttnn.uint32: torch.int32,
-    ttnn.float32: torch.float,
-    ttnn.bfloat16: torch.bfloat16,
-    ttnn.bfloat8_b: torch.float,
-}
+from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype
 
 
 @pytest.mark.parametrize(
@@ -11,14 +11,9 @@
 import ttnn
 
 from models.utility_functions import get_debug_tensor
+from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype
 from enum import Enum
 
-tt_dtype_to_torch_dtype = {
-    ttnn.uint32: torch.int32,
-    ttnn.uint16: torch.int16,
-    ttnn.bfloat16: torch.bfloat16,
-    ttnn.bfloat8_b: torch.float,
-}
 TILE_WIDTH = 32
 TILE_HEIGHT = 32

40 changes: 12 additions & 28 deletions tests/ttnn/unit_tests/tensor/test_tensor_conversion.py
@@ -11,29 +11,10 @@
 import numpy as np
 
 import ttnn
-
-tt_dtype_to_torch_dtype = {
-    ttnn.uint8: torch.uint8,
-    ttnn.uint16: torch.int16,
-    ttnn.uint32: torch.int32,
-    ttnn.int32: torch.int32,
-    ttnn.float32: torch.float,
-    ttnn.bfloat16: torch.bfloat16,
-    ttnn.bfloat8_b: torch.float,
-    ttnn.bfloat4_b: torch.float,
-}
-
-tt_dtype_to_np_dtype = {
-    ttnn.uint8: np.ubyte,
-    ttnn.uint16: np.int16,
-    ttnn.uint32: np.int32,
-    ttnn.int32: np.int32,
-    ttnn.float32: np.float32,
-    ttnn.bfloat8_b: np.float32,
-    ttnn.bfloat4_b: np.float32,
-}
+from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype, tt_dtype_to_np_dtype
 
 
+@pytest.mark.parametrize("convert_to_device", [True, False])
 @pytest.mark.parametrize(
     "tt_dtype",
     [
@@ -49,7 +30,7 @@
 )
 @pytest.mark.parametrize("shape", [(2, 3, 64, 96)])
 @pytest.mark.parametrize("python_lib", [torch, np])
-def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device):
+def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, convert_to_device, device):
     torch.manual_seed(0)
 
     if python_lib == torch:
Expand All @@ -64,7 +45,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device):

elif python_lib == np:
if tt_dtype == ttnn.bfloat16:
pytest.skip("ttnn.bloat16 dtype is not supported yet for numpy tensors!")
pytest.skip("ttnn.bfloat16 dtype is not supported yet for numpy tensors!")
dtype = tt_dtype_to_np_dtype[tt_dtype]

if dtype in {np.ubyte, np.int16, np.int32}:
@@ -82,8 +63,9 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device):
     assert tt_tensor.storage_type() == ttnn.StorageType.BORROWED
     assert tt_tensor.layout == ttnn.ROW_MAJOR_LAYOUT
 
-    tt_tensor = tt_tensor.to(device)
-    tt_tensor = tt_tensor.cpu()
+    if convert_to_device:
+        tt_tensor = tt_tensor.to(device)
+        tt_tensor = tt_tensor.cpu()
 
     if python_lib == torch:
         py_tensor_after_round_trip = tt_tensor.to_torch()
@@ -123,6 +105,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device):
 }
 
 
+@pytest.mark.parametrize("convert_to_device", [True, False])
 @pytest.mark.parametrize(
     "python_dtype_str",
     [
@@ -137,7 +120,7 @@ def test_tensor_conversion_with_tt_dtype(python_lib, shape, tt_dtype, device):
 )
 @pytest.mark.parametrize("shape", [(2, 3, 64, 96)])
 @pytest.mark.parametrize("python_lib", [torch, np])
-def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str, device):
+def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str, convert_to_device, device):
     torch.manual_seed(0)
 
     if python_lib == torch:
@@ -165,8 +148,9 @@ def test_tensor_conversion_with_python_dtype(python_lib, shape, python_dtype_str
     tt_tensor = ttnn.Tensor(py_tensor)
     assert tt_tensor.storage_type() == ttnn.StorageType.BORROWED
 
-    tt_tensor = tt_tensor.to(device)
-    tt_tensor = tt_tensor.cpu()
+    if convert_to_device:
+        tt_tensor = tt_tensor.to(device)
+        tt_tensor = tt_tensor.cpu()
 
     if python_lib == torch:
         py_tensor_after_round_trip = tt_tensor.to_torch()
122 changes: 122 additions & 0 deletions tests/ttnn/unit_tests/tensor/test_tensor_creation.py
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import os
import pathlib

import torch
import numpy as np

import ttnn
from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype


@pytest.mark.parametrize(
    "layout",
    [
        ttnn.ROW_MAJOR_LAYOUT,
        ttnn.TILE_LAYOUT,
    ],
)
@pytest.mark.parametrize(
    "tt_dtype",
    [
        ttnn.uint8,
        ttnn.uint16,
        ttnn.uint32,
        ttnn.int32,
        ttnn.float32,
        ttnn.bfloat16,
        ttnn.bfloat8_b,
        ttnn.bfloat4_b,
    ],
)
@pytest.mark.parametrize("shape", [(2, 3, 64, 96)])
def test_tensor_creation(shape, tt_dtype, layout, device):
    torch.manual_seed(0)

    dtype = tt_dtype_to_torch_dtype[tt_dtype]

    if dtype in {torch.uint8, torch.int16, torch.int32}:
        py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype)
    else:
        py_tensor = torch.rand(shape, dtype=dtype)

    tt_tensor = ttnn.Tensor(py_tensor, tt_dtype, device, layout)

    tt_tensor = tt_tensor.cpu()

    py_tensor_after_round_trip = tt_tensor.to_torch()

    assert py_tensor.dtype == py_tensor_after_round_trip.dtype
    assert py_tensor.shape == py_tensor_after_round_trip.shape

    allclose_kwargs = {}
    if tt_dtype == ttnn.bfloat8_b:
        allclose_kwargs = dict(atol=1e-2)
    elif tt_dtype == ttnn.bfloat4_b:
        allclose_kwargs = dict(atol=0.2)

    passing = torch.allclose(py_tensor, py_tensor_after_round_trip, **allclose_kwargs)
    assert passing


@pytest.mark.parametrize(
    "layout",
    [
        ttnn.ROW_MAJOR_LAYOUT,
        ttnn.TILE_LAYOUT,
    ],
)
@pytest.mark.parametrize(
    "tt_dtype",
    [
        ttnn.uint8,
        ttnn.uint16,
        ttnn.uint32,
        ttnn.int32,
        ttnn.float32,
        ttnn.bfloat16,
        ttnn.bfloat8_b,
        ttnn.bfloat4_b,
    ],
)
@pytest.mark.parametrize("shape", [(2, 3, 64, 96)])
def test_tensor_creation_api_parity(shape, tt_dtype, layout, device):
    torch.manual_seed(0)

    if tt_dtype in (ttnn.bfloat8_b, ttnn.bfloat4_b) and layout == ttnn.ROW_MAJOR_LAYOUT:
        pytest.skip("{} is only valid for ttnn.TILE_LAYOUT!".format(tt_dtype))

    dtype = tt_dtype_to_torch_dtype[tt_dtype]

    if dtype in {torch.uint8, torch.int16, torch.int32}:
        py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype)
    else:
        py_tensor = torch.rand(shape, dtype=dtype)

    tt_tensor_1 = ttnn.Tensor(py_tensor, tt_dtype, device, layout)
    tt_tensor_2 = ttnn.from_torch(py_tensor, tt_dtype, device=device, layout=layout)

    tt_tensor_1 = tt_tensor_1.cpu()
    tt_tensor_2 = tt_tensor_2.cpu()

    py_tensor_after_round_trip_1 = tt_tensor_1.to_torch()
    py_tensor_after_round_trip_2 = tt_tensor_2.to_torch()
    py_tensor_after_round_trip_3 = ttnn.to_torch(tt_tensor_1)
    py_tensor_after_round_trip_4 = ttnn.to_torch(tt_tensor_2)

    allclose_kwargs = {}
    if tt_dtype == ttnn.bfloat8_b:
        allclose_kwargs = dict(atol=1e-2)
    elif tt_dtype == ttnn.bfloat4_b:
        allclose_kwargs = dict(atol=0.2)

    # Check every round-trip result, not just the last one.
    passing = torch.allclose(py_tensor, py_tensor_after_round_trip_1, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_2, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_3, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_4, **allclose_kwargs)
    assert passing