Skip to content

Commit 990ee1b

Browse files
committed
ggml : add ggml_set_rows
Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using indices from 'c'. ref: ggml-org#8366
1 parent 58cba76 commit 990ee1b

File tree

5 files changed

+98
-2
lines changed

5 files changed

+98
-2
lines changed

ggml/include/ggml.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ extern "C" {
470470
GGML_OP_TRANSPOSE,
471471
GGML_OP_GET_ROWS,
472472
GGML_OP_GET_ROWS_BACK,
473+
GGML_OP_SET_ROWS,
473474
GGML_OP_DIAG,
474475
GGML_OP_DIAG_MASK_INF,
475476
GGML_OP_DIAG_MASK_ZERO,
@@ -1375,6 +1376,12 @@ extern "C" {
13751376
struct ggml_tensor * b, // row indices
13761377
struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
13771378

1379+
// Copy the rows of 'b' into 'a' at the row positions given by the indices in 'c'.
// The returned tensor is a view of 'a' that carries the GGML_OP_SET_ROWS operation.
GGML_API struct ggml_tensor * ggml_set_rows(
1380+
struct ggml_context * ctx,
1381+
struct ggml_tensor * a, // destination (must be F16)
1382+
struct ggml_tensor * b, // source (must be F32)
1383+
struct ggml_tensor * c); // row indices (must be I32)
1384+
13781385
GGML_API struct ggml_tensor * ggml_diag(
13791386
struct ggml_context * ctx,
13801387
struct ggml_tensor * a);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1814,6 +1814,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
18141814
{
18151815
ggml_compute_forward_get_rows_back(params, tensor);
18161816
} break;
1817+
case GGML_OP_SET_ROWS:
1818+
{
1819+
ggml_compute_forward_set_rows(params, tensor);
1820+
} break;
18171821
case GGML_OP_DIAG:
18181822
{
18191823
ggml_compute_forward_diag(params, tensor);
@@ -2167,6 +2171,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
21672171
n_tasks = n_threads;
21682172
} break;
21692173
case GGML_OP_GET_ROWS:
2174+
case GGML_OP_SET_ROWS:
21702175
{
21712176
// FIXME: get_rows can use additional threads, but the cost of launching additional threads
21722177
// decreases performance with GPU offloading

ggml/src/ggml-cpu/ops.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4470,6 +4470,65 @@ void ggml_compute_forward_get_rows(
44704470
//}
44714471
}
44724472

4473+
static void ggml_compute_forward_set_rows_f32(
4474+
const ggml_compute_params * params,
4475+
ggml_tensor * dst) {
4476+
4477+
const ggml_tensor * src0 = dst->src[0];
4478+
const ggml_tensor * src1 = dst->src[1];
4479+
4480+
GGML_TENSOR_BINARY_OP_LOCALS
4481+
4482+
const int64_t nc = ne00;
4483+
const int64_t nr = ggml_nelements(src1);
4484+
4485+
assert(ne0 == nc);
4486+
assert(ne02 == ne11);
4487+
assert(nb00 == sizeof(float));
4488+
assert(ggml_nrows(src0) == nr);
4489+
4490+
const int ith = params->ith;
4491+
const int nth = params->nth;
4492+
4493+
// rows per thread
4494+
const int dr = (nr + nth - 1)/nth;
4495+
4496+
// row range for this thread
4497+
const int ir0 = dr*ith;
4498+
const int ir1 = MIN(ir0 + dr, nr);
4499+
4500+
for (int64_t i = ir0; i < ir1; ++i) {
4501+
const int64_t i12 = i/(ne11*ne10);
4502+
const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4503+
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4504+
const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4505+
4506+
GGML_ASSERT(i01 >= 0 && i01 < ne1);
4507+
4508+
ggml_cpu_fp32_to_fp16(
4509+
(const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
4510+
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
4511+
}
4512+
}
4513+
4514+
void ggml_compute_forward_set_rows(
4515+
const ggml_compute_params * params,
4516+
ggml_tensor * dst) {
4517+
4518+
const ggml_tensor * src0 = dst->src[0];
4519+
4520+
switch (src0->type) {
4521+
case GGML_TYPE_F32:
4522+
{
4523+
ggml_compute_forward_set_rows_f32(params, dst);
4524+
} break;
4525+
default:
4526+
{
4527+
GGML_ABORT("fatal error");
4528+
}
4529+
}
4530+
}
4531+
44734532
// ggml_compute_forward_get_rows_back
44744533

44754534
static void ggml_compute_forward_get_rows_back_f32_f16(

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
5353
void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5454
void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5555
void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
56+
void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5657
void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5758
void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
5859
void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml.c

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
936936
"TRANSPOSE",
937937
"GET_ROWS",
938938
"GET_ROWS_BACK",
939+
"SET_ROWS",
939940
"DIAG",
940941
"DIAG_MASK_INF",
941942
"DIAG_MASK_ZERO",
@@ -986,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
986987
"OPT_STEP_ADAMW",
987988
};
988989

989-
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
990+
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
990991

991992
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
992993
"none",
@@ -1032,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10321033
"transpose(x)",
10331034
"get_rows(x)",
10341035
"get_rows_back(x)",
1036+
"set_rows(x)",
10351037
"diag(x)",
10361038
"diag_mask_inf(x)",
10371039
"diag_mask_zero(x)",
@@ -1082,7 +1084,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
10821084
"adamw(x)",
10831085
};
10841086

1085-
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
1087+
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
10861088

10871089
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
10881090

@@ -3395,6 +3397,28 @@ struct ggml_tensor * ggml_get_rows_back(
33953397
return result;
33963398
}
33973399

3400+
// ggml_set_rows
3401+
3402+
struct ggml_tensor * ggml_set_rows(
3403+
struct ggml_context * ctx,
3404+
struct ggml_tensor * a,
3405+
struct ggml_tensor * b,
3406+
struct ggml_tensor * c) {
3407+
GGML_ASSERT(b->ne[2] == c->ne[1]);
3408+
GGML_ASSERT(c->ne[3] == 1);
3409+
GGML_ASSERT(a->type == GGML_TYPE_F16);
3410+
GGML_ASSERT(b->type == GGML_TYPE_F32);
3411+
GGML_ASSERT(c->type == GGML_TYPE_I32);
3412+
3413+
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
3414+
3415+
result->op = GGML_OP_SET_ROWS;
3416+
result->src[0] = b;
3417+
result->src[1] = c;
3418+
3419+
return result;
3420+
}
3421+
33983422
// ggml_diag
33993423

34003424
struct ggml_tensor * ggml_diag(

0 commit comments

Comments
 (0)