Skip to content

Commit 77be71e

Browse files
committed
Vulkan: Implement glu_split logic and shader support
1 parent e708394 commit 77be71e

File tree

6 files changed

+93
-113
lines changed

6 files changed

+93
-113
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,11 @@ struct vk_op_push_constants {
659659
float param2;
660660
};
661661

662+
struct vk_op_glu_push_constants {
663+
uint32_t ne00;
664+
uint32_t mode; // 0: default, 1: swapped, 2: split
665+
};
666+
662667
struct vk_op_unary_push_constants {
663668
uint32_t ne;
664669
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -2733,8 +2738,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
27332738
#undef CREATE_UNARY
27342739

27352740
#define CREATE_GLU(name) \
2736-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
2737-
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2741+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); \
2742+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
27382743

27392744
CREATE_GLU(geglu)
27402745
CREATE_GLU(reglu)
@@ -6947,7 +6952,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
69476952
}
69486953
}
69496954

6950-
if (op == GGML_OP_SOFT_MAX) {
6955+
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
69516956
// Empty src1 is possible in soft_max, but the shader needs a buffer
69526957
vk_subbuffer subbuf_y;
69536958
if (use_src1) {
@@ -7539,12 +7544,23 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
75397544
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
75407545
}
75417546

7542-
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7543-
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7547+
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7548+
const bool swapped = (bool)dst->op_params[1];
7549+
const bool split = src1 != nullptr;
7550+
7551+
GGML_ASSERT(ggml_is_contiguous(src0));
7552+
7553+
if (!split) {
7554+
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
7555+
} else {
7556+
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
7557+
GGML_ASSERT(src0->ne[0] == dst->ne[0]);
7558+
GGML_ASSERT(src0->type == src1->type);
7559+
}
75447560

7545-
const uint32_t swapped = (uint32_t)dst->op_params[1];
7561+
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
75467562

7547-
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], swapped, 0.0f, 0.0f }, dryrun);
7563+
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)src0->ne[0], mode }, dryrun);
75487564
}
75497565

75507566
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -9003,7 +9019,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
90039019
case GGML_GLU_OP_GEGLU:
90049020
case GGML_GLU_OP_REGLU:
90059021
case GGML_GLU_OP_SWIGLU:
9006-
ggml_vk_glu(ctx, compute_ctx, src0, node, dryrun);
9022+
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
90079023
break;
90089024
default:
90099025
return false;
@@ -10725,7 +10741,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
1072510741
GGML_ABORT("fatal error");
1072610742
}
1072710743
} else if (tensor->op == GGML_OP_GLU) {
10728-
tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
10744+
if (src_clone[1] == nullptr) {
10745+
tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
10746+
} else {
10747+
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
10748+
}
1072910749
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
1073010750
if (src1 == nullptr) {
1073110751
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,13 @@
11
#version 450
22

3-
#include "generic_head.comp"
4-
#include "types.comp"
3+
#include "glu_head.comp"
54

6-
#extension GL_EXT_control_flow_attributes : enable
5+
const float GELU_COEF_A = 0.044715f;
6+
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
77

8-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
9-
10-
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11-
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12-
13-
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
14-
15-
void main() {
16-
const float GELU_COEF_A = 0.044715f;
17-
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
18-
19-
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
20-
const uint col = gl_LocalInvocationID.x;
21-
22-
const uint offset = p.KX / 2;
23-
24-
const bool swapped = p.KY > 0;
25-
26-
if (!swapped) {
27-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
28-
const uint idx = row * p.KX + i;
29-
30-
const float xi = float(data_a[idx]);
31-
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
32-
data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx + offset]));
33-
}
34-
} else {
35-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
36-
const uint idx = row * p.KX + i;
37-
38-
const float xi = float(data_a[idx + offset]);
39-
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
40-
data_d[row * offset + i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)) * float(data_a[idx]));
41-
}
42-
}
8+
float op(float a, float b) {
9+
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
10+
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
4311
}
12+
13+
#include "glu_main.comp"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#extension GL_EXT_shader_16bit_storage : require
2+
3+
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
4+
5+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
6+
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
7+
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
8+
9+
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
10+
11+
layout (push_constant) uniform parameter
12+
{
13+
uint ne00;
14+
uint mode;
15+
} p;
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
void main() {
2+
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
3+
const uint col = gl_LocalInvocationID.x;
4+
5+
if (p.mode == 0) {
6+
// Default
7+
const uint offset = p.ne00 / 2;
8+
9+
for (uint i = col; i < offset; i += BLOCK_SIZE) {
10+
const uint idx = row * p.ne00 + i;
11+
12+
data_d[row * offset + i] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
13+
}
14+
} else if (p.mode == 1) {
15+
// Swapped
16+
const uint offset = p.ne00 / 2;
17+
18+
for (uint i = col; i < offset; i += BLOCK_SIZE) {
19+
const uint idx = row * p.ne00 + i;
20+
21+
data_d[row * offset + i] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
22+
}
23+
} else {
24+
// Split
25+
for (uint i = col; i < p.ne00; i += BLOCK_SIZE) {
26+
const uint idx = row * p.ne00 + i;
27+
28+
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
29+
}
30+
}
31+
}
Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,9 @@
11
#version 450
22

3-
#include "generic_head.comp"
4-
#include "types.comp"
3+
#include "glu_head.comp"
54

6-
#extension GL_EXT_control_flow_attributes : enable
7-
8-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
9-
10-
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11-
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12-
13-
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
14-
15-
void main() {
16-
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
17-
const uint col = gl_LocalInvocationID.x;
18-
19-
const uint offset = p.KX / 2;
20-
21-
const bool swapped = p.KY > 0;
22-
23-
if (!swapped) {
24-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
25-
const uint idx = row * p.KX + i;
26-
27-
data_d[row * offset + i] = D_TYPE(max(float(data_a[idx]), 0.0f) * float(data_a[idx + offset]));
28-
}
29-
} else {
30-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
31-
const uint idx = row * p.KX + i;
32-
33-
data_d[row * offset + i] = D_TYPE(max(float(data_a[idx + offset]), 0.0f) * float(data_a[idx]));
34-
}
35-
}
5+
float op(float a, float b) {
6+
return max(a, 0.0f) * b;
367
}
8+
9+
#include "glu_main.comp"
Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,9 @@
11
#version 450
22

3-
#include "generic_head.comp"
4-
#include "types.comp"
3+
#include "glu_head.comp"
54

6-
#extension GL_EXT_control_flow_attributes : enable
7-
8-
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
9-
10-
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
11-
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
12-
13-
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
14-
15-
void main() {
16-
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
17-
const uint col = gl_LocalInvocationID.x;
18-
19-
const uint offset = p.KX / 2;
20-
21-
const bool swapped = p.KY > 0;
22-
23-
if (!swapped) {
24-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
25-
const uint idx = row * p.KX + i;
26-
27-
const float xi = float(data_a[idx]);
28-
data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx + offset]));
29-
}
30-
} else {
31-
for (uint i = col; i < offset; i += BLOCK_SIZE) {
32-
const uint idx = row * p.KX + i;
33-
34-
const float xi = float(data_a[idx + offset]);
35-
data_d[row * offset + i] = D_TYPE(xi / (1.0f + exp(-xi)) * float(data_a[idx]));
36-
}
37-
}
5+
float op(float a, float b) {
6+
return a / (1.0f + exp(-a)) * b;
387
}
8+
9+
#include "glu_main.comp"

0 commit comments

Comments
 (0)