Skip to content

Commit 2235e14

Browse files
ggerganov
authored and committed
metal : add ggml_set_rows implementation
ggml-ci
1 parent a723fdf commit 2235e14

File tree

3 files changed

+411
-184
lines changed

3 files changed

+411
-184
lines changed

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,22 @@ typedef struct {
521521
uint64_t nb2;
522522
} ggml_metal_kargs_get_rows;
523523

524+
// Argument block passed to the Metal set_rows kernels.
// The host fills every field from the identically named local in the
// GGML_OP_SET_ROWS case of ggml_metal_encode_node and binds it at index 0.
// NOTE(review): the names appear to follow the ne*/nb* convention of the
// neighbouring kargs structs (ne = element count, nb = stride in bytes;
// 0x = src0, 1x = src1, no tensor digit = dst) - confirm against the kernel.
// Field order and widths must stay in sync with the initializer on the host side.
typedef struct {
525+
int32_t nk0; // ne0 / ggml_blck_size(dst->type), as computed by the encoder
526+
int32_t ne01; // src0 dimension 1; the encoder splits this across threadgroups
527+
uint64_t nb01; // src0 strides for dimensions 1..3 (presumably bytes - confirm)
528+
uint64_t nb02;
529+
uint64_t nb03;
530+
int32_t ne11; // src1 dimensions 1 and 2
531+
int32_t ne12;
532+
uint64_t nb10; // src1 strides for dimensions 0..2 (presumably bytes - confirm)
533+
uint64_t nb11;
534+
uint64_t nb12;
535+
uint64_t nb1; // dst strides for dimensions 1..3 (presumably bytes - confirm)
536+
uint64_t nb2;
537+
uint64_t nb3;
538+
} ggml_metal_kargs_set_rows;
539+
524540
typedef struct {
525541
int64_t ne00;
526542
int64_t ne01;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 105 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+
GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+
GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212+
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213+
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
205214
GGML_METAL_KERNEL_TYPE_RMS_NORM,
206215
GGML_METAL_KERNEL_TYPE_L2_NORM,
207216
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1166,6 +1175,15 @@ @implementation GGMLMetalClass
11661175
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
11671176
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
11681177
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
1178+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true);
1179+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true);
1180+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1181+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true);
1182+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true);
1183+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true);
1184+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
1185+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
1186+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
11691187
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
11701188
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
11711189
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1630,7 +1648,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16301648

16311649
if (!use_bfloat) {
16321650
for (size_t i = 0, n = 3; i < n; ++i) {
1633-
if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
1651+
if (op->src[i] != NULL && (op->src[i]->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_BF16)) {
16341652
return false;
16351653
}
16361654
}
@@ -1798,6 +1816,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
17981816
{
17991817
return op->ne[3] == 1;
18001818
}
1819+
// SET_ROWS is supported only when src[0] is F32 and the result type is one of
// the types listed below; each of them has a matching SET_ROWS_* kernel registered.
case GGML_OP_SET_ROWS:
1820+
{
1821+
// the source rows must be F32
if (op->src[0]->type != GGML_TYPE_F32) {
1822+
return false;
1823+
}
1824+
1825+
// destination (result) type decides whether a kernel exists
switch (op->type) {
1826+
case GGML_TYPE_F32:
1827+
case GGML_TYPE_F16:
1828+
case GGML_TYPE_BF16:
1829+
case GGML_TYPE_Q8_0:
1830+
case GGML_TYPE_Q4_0:
1831+
case GGML_TYPE_Q4_1:
1832+
case GGML_TYPE_Q5_0:
1833+
case GGML_TYPE_Q5_1:
1834+
case GGML_TYPE_IQ4_NL:
1835+
return true;
1836+
default:
1837+
return false;
1838+
};
1839+
}
18011840
default:
18021841
return false;
18031842
}
@@ -3757,13 +3796,74 @@ static bool ggml_metal_encode_node(
37573796
};
37583797

37593798
[encoder setComputePipelineState:pipeline];
3760-
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
3761-
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
3762-
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
3763-
[encoder setBytes:&args length:sizeof(args) atIndex:3];
3799+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3800+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3801+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3802+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
37643803

37653804
[encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
37663805
} break;
3806+
// Encode a set_rows dispatch: pick the kernel from the destination type,
// size the threadgroup, then bind the argument block and the three buffers.
case GGML_OP_SET_ROWS:
3807+
{
3808+
id<MTLComputePipelineState> pipeline = nil;
3809+
3810+
// one pipeline per supported destination type; anything else aborts
switch (dst->type) {
3811+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline; break;
3812+
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline; break;
3813+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline; break;
3814+
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline; break;
3815+
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline; break;
3816+
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline; break;
3817+
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline; break;
3818+
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline; break;
3819+
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
3820+
default: GGML_ABORT("not implemented");
3821+
}
3822+
3823+
// ne0 divided by the block size of the destination type
// (presumably the number of dst-type blocks per row - confirm)
// NOTE(review): nk0 == 0 would divide by zero in the nrptg computation below;
// presumably unreachable for valid tensors - confirm
const int32_t nk0 = ne0/ggml_blck_size(dst->type);
3824+
3825+
int nth = 32; // SIMD width
3826+
3827+
// double the thread count until it covers nk0 or reaches the pipeline limit
while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
3828+
nth *= 2;
3829+
}
3830+
3831+
// nrptg is the y-size of the threadgroup and the divisor of ne01 in the grid
// (presumably "rows per threadgroup" - confirm against the kernel)
int nrptg = 1;
3832+
if (nth > nk0) {
3833+
// fewer than nth threads are needed along x: use ceil(nth/nk0) in y instead
nrptg = (nth + nk0 - 1)/nk0;
3834+
nth = nk0;
3835+
3836+
// rounding up may overshoot the pipeline limit; back off by one in that case
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
3837+
nrptg--;
3838+
}
3839+
}
3840+
3841+
// NOTE(review): appears redundant - nth <= nk0 already holds on both paths above
nth = MIN(nth, nk0);
3842+
3843+
// initializer order must match the field order of ggml_metal_kargs_set_rows
ggml_metal_kargs_set_rows args = {
3844+
/*.nk0 =*/ nk0,
3845+
/*.ne01 =*/ ne01,
3846+
/*.nb01 =*/ nb01,
3847+
/*.nb02 =*/ nb02,
3848+
/*.nb03 =*/ nb03,
3849+
/*.ne11 =*/ ne11,
3850+
/*.ne12 =*/ ne12,
3851+
/*.nb10 =*/ nb10,
3852+
/*.nb11 =*/ nb11,
3853+
/*.nb12 =*/ nb12,
3854+
/*.nb1 =*/ nb1,
3855+
/*.nb2 =*/ nb2,
3856+
/*.nb3 =*/ nb3,
3857+
};
3858+
3859+
// binding layout: args at index 0, then src0, src1 and dst at indices 1..3
[encoder setComputePipelineState:pipeline];
3860+
[encoder setBytes:&args length:sizeof(args) atIndex:0];
3861+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
3862+
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
3863+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
3864+
3865+
// grid: ceil(ne01/nrptg) x ne02 x ne03 threadgroups of nth x nrptg x 1 threads
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
3866+
} break;
37673867
case GGML_OP_RMS_NORM:
37683868
{
37693869
GGML_ASSERT(ne00 % 4 == 0);

0 commit comments

Comments
 (0)