Skip to content

Commit aa17d32

Browse files
vulkan: avoid using workgroup size before it is referenced
1 parent 118b4f0 commit aa17d32

14 files changed

+26
-26
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
1313

1414
void main() {
1515
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
16-
init_iq_shmem();
16+
init_iq_shmem(gl_WorkGroupSize);
1717
if (gl_LocalInvocationIndex.x != 0) {
1818
return;
1919
}

ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ void quantize(uint dst_idx, uint src_idx)
218218

219219
void main() {
220220
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
221-
init_iq_shmem();
221+
init_iq_shmem(gl_WorkGroupSize);
222222
if (gl_LocalInvocationIndex.x != 0) {
223223
return;
224224
}

ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ void main() {
1111
// Each thread handles 1 subblock (32 values with 2 scales)
1212
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
1313

14-
init_iq_shmem();
14+
init_iq_shmem(gl_WorkGroupSize);
1515

1616
if (ib >= p.nel / 256) {
1717
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ void main() {
1111
// Each thread handles 1 subblock (32 values with 2 scales)
1212
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
1313

14-
init_iq_shmem();
14+
init_iq_shmem(gl_WorkGroupSize);
1515

1616
if (ib >= p.nel / 256) {
1717
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
// Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits
1313
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
1414

15-
init_iq_shmem();
15+
init_iq_shmem(gl_WorkGroupSize);
1616

1717
if (ib >= p.nel / 256) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
// Each block contains 4 scale bytes (8 scales) for 256 output values.
1313
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
1414

15-
init_iq_shmem();
15+
init_iq_shmem(gl_WorkGroupSize);
1616

1717
if (ib >= p.nel / 256) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ void main() {
1212
// 8 threads handle 1 superblock
1313
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
1414

15-
init_iq_shmem();
15+
init_iq_shmem(gl_WorkGroupSize);
1616

1717
if (ib >= p.nel / 256) {
1818
return;

ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
1010
void main() {
1111
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
1212

13-
init_iq_shmem();
13+
init_iq_shmem(gl_WorkGroupSize);
1414

1515
const uint tid = gl_LocalInvocationID.x % 64;
1616
const uint il = tid/32;

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
105105

106106
void main() {
107107
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
108-
init_iq_shmem();
108+
init_iq_shmem(gl_WorkGroupSize);
109109
#endif
110110

111111
const uint32_t N = p.N;

ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ void main() {
1313
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
1414

1515
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
16-
init_iq_shmem();
16+
init_iq_shmem(gl_WorkGroupSize);
1717
#endif
1818

1919
if (i00 >= p.ne00) {

ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ void main() {
134134
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
135135

136136
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
137-
init_iq_shmem();
137+
init_iq_shmem(gl_WorkGroupSize);
138138
#endif
139139

140140
// do NUM_ROWS at a time, unless there aren't enough remaining rows

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
9696

9797
void main() {
9898
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
99-
init_iq_shmem();
99+
init_iq_shmem(gl_WorkGroupSize);
100100
#endif
101101

102102
#ifdef MUL_MAT_ID

ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
107107

108108
void main() {
109109
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
110-
init_iq_shmem();
110+
init_iq_shmem(gl_WorkGroupSize);
111111
#endif
112112

113113
#ifdef MUL_MAT_ID

ggml/src/ggml-vulkan/vulkan-shaders/types.comp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -380,10 +380,10 @@ const uvec2[256] iq2xxs_grid_const = {
380380

381381
shared uvec2 iq2xxs_grid[256];
382382

383-
void init_iq_shmem()
383+
void init_iq_shmem(uvec3 wgsize)
384384
{
385385
// copy the table into shared memory and sync
386-
for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += gl_WorkGroupSize.x) {
386+
for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) {
387387
iq2xxs_grid[i] = iq2xxs_grid_const[i];
388388
}
389389
barrier();
@@ -547,10 +547,10 @@ const uvec2 iq2xs_grid_const[512] = {
547547

548548
shared uvec2 iq2xs_grid[512];
549549

550-
void init_iq_shmem()
550+
void init_iq_shmem(uvec3 wgsize)
551551
{
552552
// copy the table into shared memory and sync
553-
for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += gl_WorkGroupSize.x) {
553+
for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) {
554554
iq2xs_grid[i] = iq2xs_grid_const[i];
555555
}
556556
barrier();
@@ -836,10 +836,10 @@ const uvec2 iq2s_grid_const[1024] = {
836836

837837
shared uvec2 iq2s_grid[1024];
838838

839-
void init_iq_shmem()
839+
void init_iq_shmem(uvec3 wgsize)
840840
{
841841
// copy the table into shared memory and sync
842-
for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += gl_WorkGroupSize.x) {
842+
for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) {
843843
iq2s_grid[i] = iq2s_grid_const[i];
844844
}
845845
barrier();
@@ -904,10 +904,10 @@ const uint32_t iq3xxs_grid_const[256] = {
904904

905905
shared uint32_t iq3xxs_grid[256];
906906

907-
void init_iq_shmem()
907+
void init_iq_shmem(uvec3 wgsize)
908908
{
909909
// copy the table into shared memory and sync
910-
for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += gl_WorkGroupSize.x) {
910+
for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) {
911911
iq3xxs_grid[i] = iq3xxs_grid_const[i];
912912
}
913913
barrier();
@@ -1011,10 +1011,10 @@ const uint32_t iq3s_grid_const[512] = {
10111011

10121012
shared uint32_t iq3s_grid[512];
10131013

1014-
void init_iq_shmem()
1014+
void init_iq_shmem(uvec3 wgsize)
10151015
{
10161016
// copy the table into shared memory and sync
1017-
for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += gl_WorkGroupSize.x) {
1017+
for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) {
10181018
iq3s_grid[i] = iq3s_grid_const[i];
10191019
}
10201020
barrier();
@@ -1050,11 +1050,11 @@ const int8_t kvalues_iq4nl_const[16] = {
10501050

10511051
shared FLOAT_TYPE kvalues_iq4nl[16];
10521052

1053-
void init_iq_shmem()
1053+
void init_iq_shmem(uvec3 wgsize)
10541054
{
10551055
// copy the table into shared memory and sync
1056-
if (gl_LocalInvocationIndex.x < 16) {
1057-
kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
1056+
for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) {
1057+
kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]);
10581058
}
10591059
barrier();
10601060
}

0 commit comments

Comments
 (0)