Add conv_transpose_1d_gemm operator
Signed-off-by: Salvatore Mesoraca <[email protected]>
smeso committed Sep 7, 2024
1 parent 19c5440 commit 3de2490
Showing 2 changed files with 50 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/ggml.h
@@ -1663,6 +1663,14 @@ extern "C" {
            int                   p0,  // padding
            int                   d0); // dilation

    GGML_API struct ggml_tensor * ggml_conv_transpose_1d_gemm(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel
            struct ggml_tensor  * b,   // data
            int                   s0,  // stride
            int                   p0,  // padding
            int                   d0); // dilation

    GGML_API struct ggml_tensor * ggml_conv_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,   // convolution kernel
42 changes: 42 additions & 0 deletions src/ggml.c
@@ -6774,6 +6774,48 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
    return result;
}

GGML_API struct ggml_tensor * ggml_conv_transpose_1d_gemm(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,   // KW OC IC
        struct ggml_tensor  * b,   // IW IC N
        int                   s0,
        int                   p0,
        int                   d0) {
    GGML_ASSERT(a->ne[3] == 1);
    GGML_ASSERT(b->ne[3] == 1);
    GGML_ASSERT(a->ne[2] == b->ne[1]);

    a = ggml_cont(ctx, ggml_permute(ctx, a, 2, 1, 0, 3)); // KW OC IC -> IC OC KW
    b = ggml_permute(ctx, b, 1, 0, 2, 3);                 // IW IC N  -> IC IW N
    if (a->type == b->type) {
        b = ggml_cont(ctx, b);
    } else {
        b = ggml_cast(ctx, b, a->type);
    }
    const int64_t IC = a->ne[0];
    assert(IC == b->ne[0]);
    const int64_t KW = a->ne[2];
    const int64_t OC = a->ne[1];
    const int64_t IW = b->ne[1];
    const int64_t N  = b->ne[2];
    // The following reshapes aren't necessary, in theory,
    // but they make CUDA use cublasSgemm instead of
    // cublasGemmBatchedEx.
    // The latter doesn't pass test-backend-ops
    // because of F16 approximations.
    a = ggml_reshape_4d(ctx, a, IC, OC*KW, 1, 1);
    b = ggml_reshape_4d(ctx, b, IC, IW*N,  1, 1);
    struct ggml_tensor * mulres = ggml_mul_mat(ctx, b, a);
    mulres = ggml_reshape_4d(ctx, mulres, IW, N, OC, KW);
    mulres = ggml_permute(ctx, mulres, 0, 3, 2, 1); // -> IW KW OC N
    return ggml_col2im(ctx,
            mulres,
            s0, 1 /* s1 */,
            p0, 0 /* p1 */,
            d0, 1 /* d1 */,
            1 /* KH */,
            1 /* IH */);
}

// ggml_conv_depthwise
struct ggml_tensor * ggml_conv_depthwise_2d(
        struct ggml_context * ctx,
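For reference, a minimal sketch of how the new operator could be exercised from user code. Everything except ggml_conv_transpose_1d_gemm itself is the standard ggml public API (ggml_init, ggml_new_tensor_3d, ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx); the tensor sizes, fill values, and the 16 MB context size are illustrative assumptions, and the shapes follow the comments in the patch (kernel KW x OC x IC, data IW x IC x N).

#include "ggml.h"

#include <stdio.h>

int main(void) {
    // illustrative sizes only
    const int KW = 3, OC = 4, IC = 2;  // kernel width, output channels, input channels
    const int IW = 8, N  = 1;          // input width, batch size
    const int s0 = 2, p0 = 0, d0 = 1;  // stride, padding, dilation

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // kernel is KW x OC x IC, data is IW x IC x N, as in the operator's shape comments
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, KW, OC, IC);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, IW, IC, N);
    ggml_set_f32(a, 0.5f);
    ggml_set_f32(b, 1.0f);

    struct ggml_tensor * out = ggml_conv_transpose_1d_gemm(ctx, a, b, s0, p0, d0);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

    // expected output width for a transposed convolution:
    // OW = (IW - 1)*s0 - 2*p0 + d0*(KW - 1) + 1
    printf("out: %lld x %lld x %lld\n",
        (long long) out->ne[0], (long long) out->ne[1], (long long) out->ne[2]);

    ggml_free(ctx);
    return 0;
}

With these parameters the expected output width is (8 - 1)*2 - 0 + 1*(3 - 1) + 1 = 17, the usual transposed-convolution width formula.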
