@@ -365,15 +365,20 @@ void main() {
365
365
366
366
const vec2 loadd = vec2(data_a[ib].d);
367
367
368
- uint8_t sc;
369
- uint8_t mbyte;
370
- if (is < 4) {
371
- sc = uint8_t(data_a[ib].scales[is ] & 63);
372
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
373
- } else {
374
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
375
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
376
- }
368
+ const uint scidx0 = (is < 4) ? is : (is + 4);
369
+ const uint scidx1 = (is < 4) ? is : (is - 4);
370
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
371
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
372
+ const uint mbidx0 = is + 4;
373
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
374
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
375
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
376
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
377
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
378
+
379
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
380
+ const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
381
+
377
382
const float d = loadd.x * sc;
378
383
const float m = -loadd.y * mbyte;
379
384
@@ -396,15 +401,20 @@ void main() {
396
401
397
402
const vec2 loadd = vec2(data_a[ib].d);
398
403
399
- uint8_t sc;
400
- uint8_t mbyte;
401
- if (is < 4) {
402
- sc = uint8_t(data_a[ib].scales[is ] & 63);
403
- mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
404
- } else {
405
- sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
406
- mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
407
- }
404
+ const uint scidx0 = (is < 4) ? is : (is + 4);
405
+ const uint scidx1 = (is < 4) ? is : (is - 4);
406
+ const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0;
407
+ const uint scidxshift1 = (is < 4) ? 0 : 2;
408
+ const uint mbidx0 = is + 4;
409
+ const uint mbidx1 = (is < 4) ? is + 4 : is;
410
+ const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0;
411
+ const uint mbidxshift0 = (is < 4) ? 0 : 4;
412
+ const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
413
+ const uint mbidxshift1 = (is < 4) ? 0 : 2;
414
+
415
+ const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1));
416
+ const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1));
417
+
408
418
const float d = loadd.x * sc;
409
419
const float m = -loadd.y * mbyte;
410
420
@@ -547,8 +557,8 @@ void main() {
547
557
548
558
#ifdef COOPMAT
549
559
#ifdef MUL_MAT_ID
550
- for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
551
- for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
560
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
561
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
552
562
coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);
553
563
554
564
[[unroll]] for (uint col = 0; col < BN; col += storestride) {
@@ -564,8 +574,8 @@ void main() {
564
574
#else
565
575
const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float
566
576
567
- for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
568
- for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
577
+ [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
578
+ [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
569
579
const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;
570
580
571
581
if (is_aligned && is_in_bounds) {
0 commit comments