diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index bff7dea3a539b..1a966f15c2ac3 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -518,6 +518,8 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
 
+        GGML_OP_FILL,
+
         GGML_OP_COUNT,
     };
 
@@ -1818,6 +1820,12 @@ extern "C" {
             float                 stop,
             float                 step);
 
+    // fill in-place the tensor with a constant value, return view(a)
+    GGML_API struct ggml_tensor * ggml_fill(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float                 value);
+
     // top k elements per row
     GGML_API struct ggml_tensor * ggml_top_k(
             struct ggml_context * ctx,
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index aa51dc21a5de4..504d7552b5828 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1959,6 +1959,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_arange(params, tensor);
             } break;
+        case GGML_OP_FILL:
+            {
+                ggml_compute_forward_fill(params, tensor);
+            } break;
         case GGML_OP_TIMESTEP_EMBEDDING:
             {
                 ggml_compute_forward_timestep_embedding(params, tensor);
@@ -2242,6 +2246,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_TRANSPOSE:
         case GGML_OP_GET_ROWS_BACK:
         case GGML_OP_DIAG:
+        case GGML_OP_FILL: // single task is safe; the kernel also supports row-splitting via ith/nth
             {
                 n_tasks = 1;
             } break;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 26501b7118b95..381ca1a46eb52 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6833,6 +6833,53 @@ void ggml_compute_forward_arange(
     }
 }
 
+// ggml_compute_forward_fill
+
+static void ggml_compute_forward_fill_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+    // the fill value is stored in op_params by ggml_fill
+    float v;
+    memcpy(&v, dst->op_params, sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int n  = ggml_nrows(dst);
+    const int nc = dst->ne[0];
+
+    const size_t nb0 = dst->nb[0];
+    const size_t nb1 = dst->nb[1];
+
+    GGML_ASSERT(nb0 == sizeof(float)); // rows must be contiguous
+
+    for (int j = ith; j < n; j += nth) {
+        float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
+
+        for (int i = 0; i < nc; i++) {
+            dst_ptr[i] = v;
+        }
+    }
+}
+
+void ggml_compute_forward_fill(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_fill_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 static void ggml_compute_forward_timestep_embedding_f32(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index dc081b9e66397..560e4dff056e0 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -73,6 +73,7 @@ void ggml_compute_forward_upscale(const struct ggml_compute_params * params, str
 void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 57d3e39adf758..57f11b3e43eb2 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -982,9 +982,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS",
     "CROSS_ENTROPY_LOSS_BACK",
     "OPT_STEP_ADAMW",
+
+    "FILL",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1077,9 +1079,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss(x,y)",
     "cross_entropy_loss_back(x,y)",
     "adamw(x)",
+
+    "fill(x)",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4342,6 +4346,23 @@ struct ggml_tensor * ggml_arange(
     return result;
 }
 
+// ggml_fill
+
+struct ggml_tensor * ggml_fill(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float                 value) {
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    // store the fill value in op_params so the backend kernel can read it
+    ggml_set_op_params(result, &value, sizeof(value));
+
+    result->op     = GGML_OP_FILL;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_timestep_embedding
 
 struct ggml_tensor * ggml_timestep_embedding(
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 543db93402190..280d56a849a39 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -2422,6 +2422,32 @@ struct test_clamp : public test_case {
     }
 };
 
+// GGML_OP_FILL
+struct test_fill : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    float v; // fill value
+
+    std::string vars() override {
+        return VARS_TO_STR3(type, ne, v);
+    }
+
+    test_fill(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 5, 4, 3},
+            float v = 0.5f)
+        : type(type), ne(ne), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(a, "a");
+
+        ggml_tensor * out = ggml_fill(ctx, a, v);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_DIAG_MASK_INF
 struct test_diag_mask_inf : public test_case {
     const ggml_type type;
@@ -4199,6 +4225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
     test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));
 
+    test_cases.emplace_back(new test_fill(GGML_TYPE_F32));
+
     for (ggml_type type_a : all_types) {
         for (int i = 1; i < 10; ++i) {
             test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));