Skip to content

Commit

Permalink
#13405: TTNN implementation of LENET model (#13473)
Browse files Browse the repository at this point in the history
### Ticket
[Link to the GitHub
Issue](#13405)

### Problem description
TTNN implementation for LENET model.

### What's changed
TTNN support is enabled for the LeNet model, along with a demo and the e2e and device
performance tests.

### Checklist
- [x] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [x] Model regression CI testing passes (if applicable)
- [x] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes

CI Links:

(Single-card) Tests for new models -
[Link](https://github.com/tenstorrent/tt-metal/actions/runs/12273855198/job/34247084545)(Lenet
tests passed)
All Post Commit -
[Link](https://github.com/tenstorrent/tt-metal/actions/runs/12275562532)
  • Loading branch information
sabira-mcw authored Dec 12, 2024
1 parent 6e983a7 commit 82eb413
Show file tree
Hide file tree
Showing 8 changed files with 499 additions and 0 deletions.
29 changes: 29 additions & 0 deletions models/demos/lenet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# LENET

# Platforms:
E150, WH N300, N150

## Introduction

The LeNet model is a foundational convolutional neural network (CNN) architecture that was specifically developed for handwritten digit recognition on the MNIST dataset. This pioneering model consists of several convolutional layers interspersed with pooling layers, followed by fully connected layers that output the final classification. By utilizing convolutional layers, LeNet effectively captures spatial hierarchies and local patterns in images, leading to significantly enhanced performance compared to traditional, simpler architectures. Its design laid the groundwork for many modern deep learning models used in image classification tasks today.

### Batch size: 64

Batch size determines the number of input images processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 64.

## How to Run

To run the demo for digit classification using the LeNet model, follow these instructions:

Ensure you have the necessary dependencies installed and that your environment is set up correctly for running the model.

Use the following command to execute the LeNet demo:
```
pytest models/demos/lenet/demo/demo.py::test_demo_dataset
```
This command will initiate the test for the demo dataset, allowing you to observe the model's performance in classifying handwritten digits.


## Inputs

The demo accepts inputs from the MNIST dataset, which consists of a large collection of labeled handwritten digits. The dataset provides a diverse range of examples, enabling the model to learn and generalize effectively. Each input consists of a grayscale image of a handwritten digit, which is processed through the model to produce a predicted classification.
68 changes: 68 additions & 0 deletions models/demos/lenet/demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
import ttnn

from loguru import logger

from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.lenet.tt import tt_lenet
from models.demos.lenet import lenet_utils


def run_demo_dataset(device, batch_size, iterations, model_location_generator, reset_seeds):
    """Run the LeNet demo and check that TTNN predictions match PyTorch.

    The batch returned by ``lenet_utils.get_test_data`` is evaluated once with
    the PyTorch reference model and ``iterations`` times with the TTNN model.
    The reported "accuracy" is the fraction of samples whose TTNN argmax equals
    the PyTorch argmax; the dataset labels are not consulted.

    Args:
        device: Device handle forwarded to the parameter upload and to
            ``tt_lenet.lenet``.
        batch_size: Number of samples requested from ``get_test_data``.
        iterations: Number of TTNN inference passes over the batch.
        model_location_generator: Callable used to resolve ``"model.pt"`` in
            the ``"LeNet"`` model sub-directory.
        reset_seeds: Not referenced in the body.

    Raises:
        AssertionError: If the agreement ratio is not exactly ``1.0``.
    """
    num_classes = 10
    # Only the input batch is needed; the raw images and labels are ignored.
    test_input, _images, _labels = lenet_utils.get_test_data(batch_size)

    checkpoint_path = model_location_generator("model.pt", model_subdir="LeNet")
    torch_lenet, _state_dict = lenet_utils.load_torch_lenet(checkpoint_path, num_classes)
    model = torch_lenet.float()

    # Reference output, computed once and compared against on every iteration.
    torch_output = model(test_input)

    parameters = lenet_utils.custom_preprocessor_device(
        preprocess_model_parameters(
            initialize_model=lambda: model,
            custom_preprocessor=lenet_utils.custom_preprocessor,
        ),
        device,
    )

    correct = 0
    for iters in range(iterations):
        # The TTNN model is fed the batch with the channel axis moved last.
        tt_input = ttnn.from_torch(test_input.permute(0, 2, 3, 1), dtype=ttnn.bfloat16)
        tt_output = ttnn.to_torch(tt_lenet.lenet(tt_input, device, parameters))

        torch_predicted = torch.max(torch_output.data, -1)[1]
        ttnn_predicted = torch.max(tt_output.data, -1)[1]

        for i in range(batch_size):
            logger.info(f"Iter: {iters} Sample {i}:")
            logger.info(f"torch Label: {torch_predicted[i]}")
            logger.info(f"Predicted Label: {ttnn_predicted[i]}")

            if torch_predicted[i] == ttnn_predicted[i]:
                correct += 1

    accuracy = correct / (batch_size * iterations)
    logger.info(f"Dataset Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")
    assert accuracy == 1.0, f"Expected accuracy : {1.0} Actual accuracy: {accuracy}"


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.parametrize("batch_size", [64])
@pytest.mark.parametrize("iterations", [1])
def test_demo_dataset(device, batch_size, iterations, model_location_generator, reset_seeds):
    """Pytest entry point for the LeNet demo.

    Forwards the fixtures and parametrized values unchanged to
    ``run_demo_dataset`` and returns whatever it returns.
    """
    # Positional order matches the signature of run_demo_dataset.
    return run_demo_dataset(device, batch_size, iterations, model_location_generator, reset_seeds)
103 changes: 103 additions & 0 deletions models/demos/lenet/lenet_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torchvision
import ttnn
import torchvision.transforms as transforms

from models.experimental.lenet.reference.lenet import LeNet5


def get_test_data(batch_size=64):
    """Build an input batch from the first ``batch_size`` MNIST test samples.

    The MNIST test split is read from ``./data`` (``download=True`` is passed
    to torchvision).  Each image is resized to 32x32, converted to a tensor and
    normalized before being concatenated into one batch.

    Args:
        batch_size: Number of leading test samples to use.

    Returns:
        A tuple ``(batch, images, outputs)``: the concatenated input tensor
        (presumably of shape ``(batch_size, 1, 32, 32)`` -- TODO confirm), the
        untransformed images as returned by the dataset, and the labels as
        returned by the dataset.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1325,), std=(0.3105,)),
        ]
    )
    mnist_test = torchvision.datasets.MNIST(root="./data", train=False, download=True)

    samples = [mnist_test[index] for index in range(batch_size)]
    images = [image for image, _ in samples]
    outputs = [label for _, label in samples]

    # Each transformed image gets a leading batch axis so the tensors can be
    # concatenated along dimension 0.
    batch = torch.cat([preprocess(image).unsqueeze(0) for image in images])
    return batch, images, outputs


def load_torch_lenet(path, num_classes):
    """Load the reference LeNet5 model from a checkpoint, on CPU, in eval mode.

    Args:
        path: Location of the checkpoint; it must contain a
            ``"model_state_dict"`` entry.
        num_classes: Forwarded to the ``LeNet5`` constructor.

    Returns:
        A tuple ``(model, state_dict)`` where ``state_dict`` is the
        checkpoint's ``"model_state_dict"`` entry.
    """
    model = LeNet5(num_classes).to("cpu")

    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    cpu = torch.device("cpu")
    checkpoint = torch.load(path, map_location=cpu)
    state_dict = checkpoint["model_state_dict"]

    model.load_state_dict(state_dict)
    model.eval()
    return model, state_dict


def custom_preprocessor(model, device):
    """Convert LeNet weights to TTNN tensors, folding batch norm into each conv.

    For ``layer1`` and ``layer2`` the first two sub-modules are treated as a
    convolution followed by a batch-norm layer; the batch-norm statistics and
    affine parameters are folded into the convolution weight and bias.  For
    ``fc``, ``fc1`` and ``fc2`` the weight is transposed and, together with the
    bias, converted to tile layout.

    Args:
        model: Module exposing ``layer1``, ``layer2``, ``fc``, ``fc1`` and
            ``fc2`` attributes.
        device: Not referenced in the body.

    Returns:
        A dict mapping each layer name to ``{"weight": ..., "bias": ...}``
        holding bfloat16 TTNN tensors.
    """
    parameters = {}

    for layer in ("layer1", "layer2", "fc", "fc1", "fc2"):
        module = getattr(model, layer)

        if layer.startswith("layer"):
            conv_layer, bn_layer = module[0], module[1]

            # Use the epsilon the batch-norm layer was built with; the previous
            # hard-coded 1e-05 is kept as the fallback so results are unchanged
            # for layers using that value.
            eps = getattr(bn_layer, "eps", 1e-05)

            # Per-output-channel factor: gamma / sqrt(running_var + eps).
            bn_scale = bn_layer.weight / torch.sqrt(bn_layer.running_var + eps)

            weight = conv_layer.weight * bn_scale[:, None, None, None]
            if conv_layer.bias is not None:
                bias = (conv_layer.bias - bn_layer.running_mean) * bn_scale + bn_layer.bias
            else:
                bias = bn_layer.bias - bn_layer.running_mean * bn_scale

            weight = ttnn.from_torch(weight, dtype=ttnn.bfloat16)
            bias = ttnn.from_torch(bias, dtype=ttnn.bfloat16)
            bias = ttnn.reshape(bias, (1, 1, 1, -1))
        else:
            # Linear layers: swap the two weight axes before conversion.
            weight = torch.permute(module.weight, (1, 0))
            weight = ttnn.from_torch(weight, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)
            bias = ttnn.from_torch(module.bias, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT)

        parameters[layer] = {"weight": weight, "bias": bias}

    return parameters


def custom_preprocessor_device(parameters, device):
    """Move the ``fc``, ``fc1`` and ``fc2`` weights and biases onto ``device``.

    Args:
        parameters: Object exposing ``fc``, ``fc1`` and ``fc2`` attributes,
            each with ``weight`` and ``bias`` attributes.
        device: Target passed to ``ttnn.to_device``.

    Returns:
        The same ``parameters`` object, updated in place.
    """
    for layer_name in ("fc", "fc1", "fc2"):
        layer = getattr(parameters, layer_name)
        for tensor_name in ("weight", "bias"):
            on_device = ttnn.to_device(getattr(layer, tensor_name), device)
            setattr(layer, tensor_name, on_device)

    return parameters
126 changes: 126 additions & 0 deletions models/demos/lenet/tests/test_perf_lenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import pytest
import ttnn
import time
import ttnn

from loguru import logger

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import (
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
)
from models.utility_functions import is_grayskull, is_wormhole_b0
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.perf.perf_utils import prep_perf_report
from models.demos.lenet import lenet_utils
from models.demos.lenet.tt import tt_lenet


def get_expected_times(tt_lenet):
    """Return the expected performance figures for the current architecture.

    Args:
        tt_lenet: Key used to look up the figures (the caller passes the
            ``tt_lenet`` model module).

    Returns:
        A tuple ``(expected_compile_time, expected_inference_time)``
        (presumably in seconds -- TODO confirm).

    Raises:
        RuntimeError: If the device is neither Grayskull nor Wormhole B0.
            Previously the function returned ``None`` in that case, which made
            the caller fail with an opaque unpacking error.
        KeyError: If ``tt_lenet`` is not the key the figures are stored under.
    """
    if is_grayskull():
        expected_times = {tt_lenet: (7.62, 0.05)}
    elif is_wormhole_b0():
        expected_times = {tt_lenet: (10.75, 0.049)}
    else:
        raise RuntimeError("No expected LeNet performance figures are defined for this device architecture")

    return expected_times[tt_lenet]


@pytest.mark.parametrize(
    "batch_size",
    [64],
)
@pytest.mark.parametrize(
    "tt_lenet",
    [tt_lenet],
)
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
@pytest.mark.models_performance_bare_metal
def test_perf_lenet(device, batch_size, tt_lenet, model_location_generator, reset_seeds):
    """Measure end-to-end LeNet call time on device and compare with expectations.

    The model is invoked 100 times on the same input.  The first duration is
    reported as compile-plus-inference time; the mean of the remaining
    durations is the inference time, which must be below the expected value
    returned by ``get_expected_times``.

    Raises:
        AssertionError: If the mean inference time is not below the expected
            inference time.
    """
    num_classes = 10
    # Only the input batch is needed; raw images and labels are ignored.
    test_input, _images, _labels = lenet_utils.get_test_data(batch_size)
    pt_model_path = model_location_generator("model.pt", model_subdir="LeNet")
    torch_lenet, _state_dict = lenet_utils.load_torch_lenet(pt_model_path, num_classes)
    model = torch_lenet.float()
    disable_persistent_kernel_cache()

    parameters = preprocess_model_parameters(
        initialize_model=lambda: model,
        custom_preprocessor=lenet_utils.custom_preprocessor,
    )
    parameters = lenet_utils.custom_preprocessor_device(parameters, device)
    x = test_input.permute(0, 2, 3, 1)
    x = ttnn.from_torch(x, dtype=ttnn.bfloat16)
    durations = []

    for _ in range(100):
        # perf_counter is monotonic, so the measured interval is not affected
        # by system clock adjustments the way time.time() differences can be.
        start = time.perf_counter()
        ttnn_output = tt_lenet.lenet(
            device=device,
            input_tensor=x,
            parameters=parameters,
        )
        end = time.perf_counter()
        durations.append(end - start)
        # NOTE(review): the original indentation was lost; this call is assumed
        # to sit inside the loop so that only the first iteration runs with the
        # persistent kernel cache disabled -- confirm.
        enable_persistent_kernel_cache()

    # First run is treated as compile + inference, the rest as pure inference.
    inference_and_compile_time, *inference_times = durations
    inference_time = sum(inference_times) / len(inference_times)
    expected_compile_time, expected_inference_time = get_expected_times(tt_lenet)

    prep_perf_report(
        model_name="tt_lenet",
        batch_size=batch_size,
        inference_and_compile_time=inference_and_compile_time,
        inference_time=inference_time,
        expected_compile_time=expected_compile_time,
        expected_inference_time=expected_inference_time,
        comments="",
        inference_time_cpu=0.0,
    )

    logger.info(f"Compile time: {inference_and_compile_time - inference_time}")
    logger.info(f"Inference time: {inference_time}")
    logger.info(f"Inference times: {inference_times}")
    logger.info(f"Sample(s) per second: {1 / inference_time * batch_size}")
    assert (
        inference_time < expected_inference_time
    ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}"
    logger.info("Exit Lenet perf test")


@pytest.mark.parametrize(
    "batch_size",
    [64],
)
@pytest.mark.models_device_performance_bare_metal
def test_perf_device_bare_metal(batch_size, reset_seeds):
    """Run the LeNet integration test under the device profiler and check throughput.

    The pytest command below is handed to ``run_device_perf``; the resulting
    ``"AVG DEVICE KERNEL SAMPLES/S"`` value is checked by ``check_device_perf``
    against the per-architecture expectation using ``margin``, and a report is
    written with ``prep_device_perf_report``.

    The test is skipped when the device is neither Grayskull nor Wormhole B0,
    because no expected value exists for it.  Previously this case failed with
    an ``UnboundLocalError`` on ``expected_perf``.
    """
    subdir = "tt_lenet"
    num_iterations = 1
    margin = 0.03
    if is_grayskull():
        expected_perf = 83102.20
    elif is_wormhole_b0():
        expected_perf = 46313.985
    else:
        pytest.skip("No expected LeNet device performance is defined for this device architecture")

    command = "pytest tests/ttnn/integration_tests/lenet/test_lenet.py"
    cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

    inference_time_key = "AVG DEVICE KERNEL SAMPLES/S"
    expected_perf_cols = {inference_time_key: expected_perf}

    post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
    expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True)
    prep_device_perf_report(
        model_name=f"tt_lenet{batch_size}",
        batch_size=batch_size,
        post_processed_results=post_processed_results,
        expected_results=expected_results,
        comments="",
    )
Loading

0 comments on commit 82eb413

Please sign in to comment.