@@ -659,6 +659,11 @@ struct vk_op_push_constants {
     float param2;
 };
 
+struct vk_op_glu_push_constants {
+    uint32_t ne00;
+    uint32_t mode; // 0: default, 1: swapped, 2: split
+};
+
 struct vk_op_unary_push_constants {
     uint32_t ne;
     uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
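The dedicated push-constant struct lets one shader family cover all three input layouts. A minimal CPU-side sketch of what each `mode` value implies, assuming ggml's convention that the activation applies to the first operand (`glu_row_ref` and `silu` are illustrative names, not part of the patch; `ne00` is src0's row length, matching the push constant):

```cpp
#include <cmath>
#include <cstdint>

static float silu(float a) { return a / (1.0f + std::exp(-a)); } // SwiGLU activation

// Reference semantics for one row; the real work happens in the GLSL shaders.
static void glu_row_ref(const float * src0, const float * src1, float * dst,
                        uint32_t ne00, uint32_t mode) {
    const uint32_t half  = ne00 / 2;
    const uint32_t n_out = mode == 2 ? ne00 : half; // split keeps the full width
    for (uint32_t i = 0; i < n_out; ++i) {
        float a, b;
        if (mode == 2)      { a = src0[i];        b = src1[i];        } // split: second operand in src1
        else if (mode == 1) { a = src0[i + half]; b = src0[i];        } // swapped halves
        else                { a = src0[i];        b = src0[i + half]; } // default layout
        dst[i] = silu(a) * b; // GEGLU/REGLU substitute GELU/ReLU for SiLU
    }
}
```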
@@ -2733,8 +2738,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #undef CREATE_UNARY
 
 #define CREATE_GLU(name) \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
@@ -6947,7 +6952,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_SOFT_MAX) {
+    if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
         // Empty src1 is possible in soft_max, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
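GGML_OP_GLU joins soft_max here because fused GLU dispatches without a src1, yet the pipelines now always expect three descriptors. The branch just below the shown context presumably keeps the existing dummy-binding pattern; a sketch with assumed names (`d_X`, `d_Y`, `y_buf_offset`, `x_sz`, `y_sz` are not visible in this hunk):

```cpp
// Sketch only: bind the real src1 when present (split GLU), otherwise reuse
// src0's buffer as a dummy so descriptor slot 1 is still backed by a valid buffer.
vk_subbuffer subbuf_y;
if (use_src1) {
    subbuf_y = { d_Y, y_buf_offset, y_sz }; // split GLU: second input tensor
} else {
    subbuf_y = { d_X, 0, x_sz };            // fused GLU / soft_max: never read
}
```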
@@ -7539,12 +7544,23 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
 
-static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
-    GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const bool swapped = (bool)dst->op_params[1];
+    const bool split = src1 != nullptr;
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    if (!split) {
+        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
+    } else {
+        GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+        GGML_ASSERT(src0->ne[0] == dst->ne[0]);
+        GGML_ASSERT(src0->type == src1->type);
+    }
 
-    const uint32_t swapped = (uint32_t)dst->op_params[1];
+    const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
+    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun);
 }
 
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
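For reference, the two call paths that reach the new `ggml_vk_glu` signature are ggml's graph-level builders; a usage fragment assuming an initialized `ggml_context * ctx` and illustrative tensors `cur`, `gate`, `up`:

```cpp
// Fused: value and gate share one tensor, halved along dim 0 -> src1 == nullptr,
// so ggml_vk_glu selects mode 0 (or 1 when the swapped op param is set).
struct ggml_tensor * fused = ggml_glu(ctx, cur, GGML_GLU_OP_SWIGLU, /*swapped=*/false);

// Split: gate and value are separate, equally shaped tensors -> mode 2.
struct ggml_tensor * split = ggml_glu_split(ctx, gate, up, GGML_GLU_OP_SWIGLU);
```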
@@ -9003,7 +9019,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_GLU_OP_GEGLU:
     case GGML_GLU_OP_REGLU:
     case GGML_GLU_OP_SWIGLU:
-        ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
+        ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
         break;
     default:
         return false;
@@ -10725,7 +10741,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
             GGML_ABORT("fatal error");
         }
     } else if (tensor->op == GGML_OP_GLU) {
-        tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
+        if (src_clone[1] == nullptr) {
+            tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
+        } else {
+            tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
+        }
     } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
         if (src1 == nullptr) {
             tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);