#13127: Use tensor spec in conversion between Python and TT tensors
- Significant changes in ttnn/cpp/pybind11/pytensor.cpp:
  * Use tensor spec in create_owned_tensor
  * Add conversion between ROW_MAJOR and TILE layouts for the ttnn.Tensor(...)/tensor.to(...) APIs (usage sketch below the commit stats):
    ** For ttnn.Tensor(python_tensor, ...), handling is now internal and no longer goes through .to(layout)
    ** For ttnn.Tensor(float_vector, ...), use .to(layout) to convert to TILE if needed
    ** Make the tilize, tilize_to_list, and untilize Python utility functions no-ops and mark them as deprecated
  * Add an analogous create_row_major_owned_buffer built from the tensor buffer:
    ** Commonize handling of BFLOAT8_B/BFLOAT4_B as float tensors/buffers
    ** Always use OwnedBuffer when conversion to/from TILE layout is required
  * Automatically deduce the Python dtype from owned buffers instead of mapping it from the tt dtype
  * Set defaults for the pybound init so it is more usable
  * Invert the meaning of enable_borrow (now called override_enable_borrow):
    ** Make enable_borrow internal to create_tt_tensor_from_py_data
- Update tensor init documentation and sample code for the tile arg and for creating tensors on device
- Add memory_config() to TensorSpec
- Commonize the tt_dtype_to_torch_dtype and tt_dtype_to_np_dtype dicts across ttnn unit tests
- Add a test for host-side tensor conversion in tests/ttnn/unit_tests/tensor/test_tensor_conversion.py
- Add new tests in tests/ttnn/unit_tests/tensor/test_tensor_creation.py:
  * Coverage for directly creating device tensors with ttnn.Tensor(...)
  * Coverage for API parity between ttnn.from_device/ttnn.to_device and ttnn.Tensor(...)/tensor.to(...)
1 parent bc00438 · commit 5b04331 · 13 changed files with 444 additions and 366 deletions.
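A minimal sketch of what the reworked API enables, based on the calls exercised in the new tests (the device id and shape are illustrative, and ttnn.open_device/ttnn.close_device are assumed as the session helpers): ttnn.Tensor(...) now performs the ROW_MAJOR/TILE conversion internally, so a torch tensor can be placed on device in TILE layout in one step instead of going through the deprecated tilize utilities.

import torch
import ttnn

device = ttnn.open_device(device_id=0)  # illustrative device id

py_tensor = torch.rand((2, 3, 64, 96), dtype=torch.bfloat16)

# ttnn.Tensor(...) now converts to TILE layout internally; no explicit
# .to(ttnn.TILE_LAYOUT) hop through a host tensor is required.
tt_tensor = ttnn.Tensor(py_tensor, ttnn.bfloat16, device, ttnn.TILE_LAYOUT)

# Round trip back to torch; the reverse conversion is also handled internally.
round_tripped = tt_tensor.cpu().to_torch()
assert torch.allclose(py_tensor, round_tripped)

ttnn.close_device(device)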
tests/ttnn/unit_tests/tensor/test_tensor_creation.py (new file)
@@ -0,0 +1,122 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest

import os
import pathlib

import torch
import numpy as np

import ttnn
from tests.ttnn.utils_for_testing import tt_dtype_to_torch_dtype


@pytest.mark.parametrize(
    "layout",
    [
        ttnn.ROW_MAJOR_LAYOUT,
        ttnn.TILE_LAYOUT,
    ],
)
@pytest.mark.parametrize(
    "tt_dtype",
    [
        ttnn.uint8,
        ttnn.uint16,
        ttnn.uint32,
        ttnn.int32,
        ttnn.float32,
        ttnn.bfloat16,
        ttnn.bfloat8_b,
        ttnn.bfloat4_b,
    ],
)
@pytest.mark.parametrize("shape", [(2, 3, 64, 96)])
def test_tensor_creation(shape, tt_dtype, layout, device):
    torch.manual_seed(0)

    dtype = tt_dtype_to_torch_dtype[tt_dtype]

    if dtype in {torch.uint8, torch.int16, torch.int32}:
        py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype)
    else:
        py_tensor = torch.rand(shape, dtype=dtype)

    tt_tensor = ttnn.Tensor(py_tensor, tt_dtype, device, layout)

    tt_tensor = tt_tensor.cpu()

    py_tensor_after_round_trip = tt_tensor.to_torch()

    assert py_tensor.dtype == py_tensor_after_round_trip.dtype
    assert py_tensor.shape == py_tensor_after_round_trip.shape

    allclose_kwargs = {}
    if tt_dtype == ttnn.bfloat8_b:
        allclose_kwargs = dict(atol=1e-2)
    elif tt_dtype == ttnn.bfloat4_b:
        allclose_kwargs = dict(atol=0.2)

    passing = torch.allclose(py_tensor, py_tensor_after_round_trip, **allclose_kwargs)
    assert passing


@pytest.mark.parametrize(
    "layout",
    [
        ttnn.ROW_MAJOR_LAYOUT,
        ttnn.TILE_LAYOUT,
    ],
)
@pytest.mark.parametrize(
    "tt_dtype",
    [
        ttnn.uint8,
        ttnn.uint16,
        ttnn.uint32,
        ttnn.int32,
        ttnn.float32,
        ttnn.bfloat16,
        ttnn.bfloat8_b,
        ttnn.bfloat4_b,
    ],
)
@pytest.mark.parametrize("shape", [(2, 3, 64, 96)])
def test_tensor_creation_api_parity(shape, tt_dtype, layout, device):
    torch.manual_seed(0)

    if tt_dtype in (ttnn.bfloat8_b, ttnn.bfloat4_b) and layout == ttnn.ROW_MAJOR_LAYOUT:
        pytest.skip("{} is only valid for ttnn.TILE_LAYOUT!".format(tt_dtype))

    dtype = tt_dtype_to_torch_dtype[tt_dtype]

    if dtype in {torch.uint8, torch.int16, torch.int32}:
        py_tensor = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, shape, dtype=dtype)
    else:
        py_tensor = torch.rand(shape, dtype=dtype)

    tt_tensor_1 = ttnn.Tensor(py_tensor, tt_dtype, device, layout)
    tt_tensor_2 = ttnn.from_torch(py_tensor, tt_dtype, device=device, layout=layout)

    tt_tensor_1 = tt_tensor_1.cpu()
    tt_tensor_2 = tt_tensor_2.cpu()

    py_tensor_after_round_trip_1 = tt_tensor_1.to_torch()
    py_tensor_after_round_trip_2 = tt_tensor_2.to_torch()
    py_tensor_after_round_trip_3 = ttnn.to_torch(tt_tensor_1)
    py_tensor_after_round_trip_4 = ttnn.to_torch(tt_tensor_2)

    allclose_kwargs = {}
    if tt_dtype == ttnn.bfloat8_b:
        allclose_kwargs = dict(atol=1e-2)
    elif tt_dtype == ttnn.bfloat4_b:
        allclose_kwargs = dict(atol=0.2)

    # Accumulate all four comparisons instead of overwriting `passing`,
    # so every round-trip path is actually checked by the assert below.
    passing = torch.allclose(py_tensor, py_tensor_after_round_trip_1, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_2, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_3, **allclose_kwargs)
    passing = passing and torch.allclose(py_tensor, py_tensor_after_round_trip_4, **allclose_kwargs)
    assert passing
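The new tests can be run directly with pytest, assuming a tt-metal checkout with an accessible device:

pytest tests/ttnn/unit_tests/tensor/test_tensor_creation.py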