Skip to content

Commit 20a6246

Browse files
committed
vulkan: avoid using Float16 capability in scalar FA
1 parent 615958f commit 20a6246

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -137,7 +137,7 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
137137
shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
138138
shared vec4 tmpshv4[gl_WorkGroupSize.x];
139139

140-
shared float16_t masksh[Bc][Br];
140+
shared float masksh[Bc][Br];
141141
shared vec4 Qf[Br][D / 4];
142142

143143
void main() {
@@ -296,14 +296,14 @@ void main() {
296296
uint32_t c = (idx + tid) % Bc;
297297
uint32_t r = (idx + tid) / Bc;
298298
if (idx + tid < Bc * Br) {
299-
masksh[c][r] = data_m[(i * Br + r) * m_stride + (j * Bc + c)];
299+
masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
300300
}
301301
}
302302
barrier();
303303

304304
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
305305
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
306-
float mvf = float(masksh[c * cols_per_iter + col_tid][r]);
306+
float mvf = masksh[c * cols_per_iter + col_tid][r];
307307

308308
Sf[r][c] += slope[r]*mvf;
309309
}

0 commit comments

Comments
 (0)