Skip to content

Commit ab3971f

Browse files
authored
vulkan: workaround FA compile failures on macos (#13517)
1 parent e5c834f commit ab3971f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,14 +12,15 @@
12 12

13 13
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
14 14

15+
layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
15 16
layout (constant_id = 1) const uint32_t Br = 1;
16 17
layout (constant_id = 2) const uint32_t Bc = 32;
17 18
layout (constant_id = 3) const uint32_t D = 32;
18 19

19 20
layout (constant_id = 5) const uint32_t D_split = 16;
20 21
const uint32_t D_per_thread = D / D_split;
21 22

22-
const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
23+
const uint32_t cols_per_iter = WorkGroupSize / D_split;
23 24
const uint32_t cols_per_thread = Bc / cols_per_iter;
24 25

25 26
layout (push_constant) uniform parameter {
@@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
134 135
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
135 136
}
136 137

137-
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
138-
shared vec4 tmpshv4[gl_WorkGroupSize.x];
138+
shared FLOAT_TYPE tmpsh[WorkGroupSize];
139+
shared vec4 tmpshv4[WorkGroupSize];
139 140

140 141
shared float masksh[Bc][Br];
141 142
shared vec4 Qf[Br][D / 4];

0 commit comments

Comments
 (0)