Skip to content

Commit

Permalink
[tt-train] GPT2-S Matmul tests (#15937)
Browse files Browse the repository at this point in the history
### Problem description
Some matmuls fail during GPT2-S training.

### What's changed
Add tests for all matmul combinations from GPT2-S. 

### Checklist
- [x] Post commit CI passes
https://github.com/tenstorrent/tt-metal/actions/runs/12286281817
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
rfurko-tt authored Dec 12, 2024
1 parent 7e1ec65 commit bc00438
Showing 1 changed file with 95 additions and 0 deletions.
95 changes: 95 additions & 0 deletions tt-train/tests/model/gpt2s_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#include <gtest/gtest.h>

#include <core/ttnn_all_includes.hpp>

#include "autograd/auto_context.hpp"
#include "core/compute_kernel_config.hpp"
#include "core/tt_tensor_utils.hpp"

// Outcome a matmul test case is expected to produce.
enum class ExpectedResult {
    OK,     // the matmul call is expected to finish without throwing
    ERROR,  // the matmul call is expected to throw
};

// Operand shapes and transpose flags describing a single matmul invocation.
struct MatmulInput {
    // Shape used to allocate the first operand.
    ttnn::Shape shape_a;
    // Shape used to allocate the second operand.
    ttnn::Shape shape_b;
    // Transpose flags forwarded to the matmul call; both default to "no transpose".
    bool transpose_a = false;
    bool transpose_b = false;
};

// One table entry: a matmul input paired with the outcome the test expects.
struct MatmulTest {
// Operand shapes and transpose flags for the matmul call.
MatmulInput input;
// Whether the call is expected to succeed (OK) or to throw (ERROR).
ExpectedResult expected_result;
};

// Matmul tests are based on GPT2-S model with batch size 64.
//
// Each table entry reads {{shape_a, shape_b, transpose_a, transpose_b}, expected}.
// OK entries must run through ttnn::matmul without throwing; ERROR entries
// record combinations for which this test expects the call to throw.
TEST(GPT2SBatch64Test, Matmul) {
    const std::vector<MatmulTest> tests = {
        {{{64, 12, 64, 1024}, {64, 12, 1024, 64}, false, false}, ExpectedResult::OK},
        {{{64, 12, 1024, 64}, {64, 12, 1024, 64}, false, true}, ExpectedResult::OK},
        {{{64, 12, 1024, 64}, {64, 12, 1024, 64}, true, false}, ExpectedResult::OK},
        {{{64, 12, 1024, 64}, {64, 12, 64, 1024}, false, false}, ExpectedResult::OK},
        {{{64, 12, 1024, 1024}, {64, 12, 1024, 64}, false, false}, ExpectedResult::OK},
        {{{768, 65536}, {65536, 96}, false, false}, ExpectedResult::OK},
        {{{65536, 768}, {65536, 96}, true, false}, ExpectedResult::OK},
        {{{65536, 96}, {1, 1, 96, 768}, false, false}, ExpectedResult::ERROR},
        {{{65536, 96}, {1, 1, 768, 96}, false, true}, ExpectedResult::ERROR},
        {{{3072, 65536}, {65536, 768}, false, false}, ExpectedResult::OK},
        {{{65536, 3072}, {65536, 768}, true, false}, ExpectedResult::OK},
        {{{65536, 768}, {1, 1, 768, 3072}, false, false}, ExpectedResult::ERROR},
        {{{65536, 768}, {1, 1, 3072, 768}, false, true}, ExpectedResult::ERROR},
        {{{768, 65536}, {65536, 3072}, false, false}, ExpectedResult::OK},
        {{{65536, 768}, {65536, 3072}, true, false}, ExpectedResult::OK},
        {{{65536, 3072}, {1, 1, 3072, 768}, false, false}, ExpectedResult::ERROR},
        {{{65536, 3072}, {1, 1, 768, 3072}, false, true}, ExpectedResult::ERROR},
        {{{65536, 3072}, {3072, 768}, false, false}, ExpectedResult::ERROR},
        {{{65536, 3072}, {768, 3072}, false, true}, ExpectedResult::ERROR},
        {{{768, 65536}, {65536, 768}, false, false}, ExpectedResult::OK},
        {{{65536, 768}, {65536, 768}, true, false}, ExpectedResult::OK},
        {{{65536, 768}, {1, 1, 768, 768}, false, false}, ExpectedResult::ERROR},
        {{{768, 65536}, {1, 1, 768, 768}, true, false}, ExpectedResult::ERROR},
        {{{768, 65536}, {65536, 2304}, false, false}, ExpectedResult::OK},
        {{{65536, 768}, {65536, 2304}, true, false}, ExpectedResult::OK},
        {{{65536, 768}, {768, 50257}, false, false}, ExpectedResult::ERROR},
        {{{65536, 768}, {50257, 768}, false, true}, ExpectedResult::ERROR},
        {{{65536, 50257}, {50257, 768}, false, false}, ExpectedResult::ERROR},
    };

    // Logs the operand shapes and flags, then issues the matmul. The product is
    // discarded: the test only checks whether the call throws.
    auto run_matmul = [](const auto& a, const auto& b, bool transpose_a, bool transpose_b) {
        fmt::println(
            "Running matmul with shapes {} and {}, transpose_a {} transpose_b {}",
            a.get_shape(),
            b.get_shape(),
            transpose_a,
            transpose_b);
        [[maybe_unused]] auto c = ttnn::matmul(
            a,
            b,
            transpose_a,
            transpose_b,
            /* memory_config */ std::nullopt,
            /* dtype */ std::nullopt,
            /* program_config */ std::nullopt,
            /* activation */ std::nullopt,
            /* compute_kernel_config */
            ttml::core::ComputeKernelConfig::matmul(),
            /* core_grid */ ttnn::CoreGrid{7, 8},
            /* output_tile */ std::nullopt);
    };

    // The device handle does not depend on the test case, so look it up once.
    auto* device = &ttml::autograd::ctx().get_device();

    unsigned case_index = 0;
    for (const auto& [input, expected_result] : tests) {
        // Bind by reference: copying the shapes for every case is unnecessary.
        const auto& [shape_a, shape_b, transpose_a, transpose_b] = input;

        // Attach the table position to any failure reported below so the
        // offending entry can be identified from the test output.
        SCOPED_TRACE(
            testing::Message() << "matmul case #" << case_index << " (transpose_a=" << transpose_a
                               << ", transpose_b=" << transpose_b << ", expected "
                               << (expected_result == ExpectedResult::OK ? "OK" : "ERROR") << ")");
        ++case_index;

        // Operands are created per case and scoped to this iteration.
        auto a = ttml::core::empty(shape_a, device, {});
        auto b = ttml::core::empty(shape_b, device, {});

        if (expected_result == ExpectedResult::OK) {
            EXPECT_NO_THROW(run_matmul(a, b, transpose_a, transpose_b));
        } else {
            EXPECT_ANY_THROW(run_matmul(a, b, transpose_a, transpose_b));
        }
    }
}

0 comments on commit bc00438

Please sign in to comment.