@@ -397,19 +397,19 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
397
397
#nvvm.shape <m = 64 , n = 8 , k = 32 >,
398
398
D [<s32 >, #nvvm.wgmma_scale_out <one >, <satfinite >],
399
399
A [<s8 >, #nvvm.wgmma_scale_in <one >, <row >],
400
- B [<s8 >, #nvvm.wgmma_scale_in <one >, <row >]
400
+ B [<s8 >, #nvvm.wgmma_scale_in <one >, <col >]
401
401
: !mat16i32 -> !mat16i32
402
402
%result2 = nvvm.wgmma.mma_async %descA , %descB , %result1 ,
403
403
#nvvm.shape <m = 64 , n = 8 , k = 32 >,
404
404
D [<s32 >, #nvvm.wgmma_scale_out <one >, <satfinite >],
405
405
A [<s8 >, #nvvm.wgmma_scale_in <one >, <row >],
406
- B [<s8 >, #nvvm.wgmma_scale_in <one >, <row >]
406
+ B [<s8 >, #nvvm.wgmma_scale_in <one >, <col >]
407
407
: !mat16i32 -> !mat16i32
408
408
%result3 = nvvm.wgmma.mma_async %descA , %descB , %result2 ,
409
409
#nvvm.shape <m = 64 , n = 8 , k = 32 >,
410
410
D [<s32 >, #nvvm.wgmma_scale_out <one >, <satfinite >],
411
411
A [<s8 >, #nvvm.wgmma_scale_in <one >, <row >],
412
- B [<s8 >, #nvvm.wgmma_scale_in <one >, <row >]
412
+ B [<s8 >, #nvvm.wgmma_scale_in <one >, <col >]
413
413
: !mat16i32 -> !mat16i32
414
414
return %result3 : !mat16i32
415
415
}
@@ -458,19 +458,19 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
458
458
#nvvm.shape <m = 64 , n = 8 , k = 32 >,
459
459
D [<s32 >, #nvvm.wgmma_scale_out <one >],
460
460
A [<u8 >, #nvvm.wgmma_scale_in <one >, <row >],
461
- B [<u8 >, #nvvm.wgmma_scale_in <one >, <row >]
461
+ B [<u8 >, #nvvm.wgmma_scale_in <one >, <col >]
462
462
: !mat16i32 -> !mat16i32
463
463
%result2 = nvvm.wgmma.mma_async %descA , %descB , %result1 ,
464
464
#nvvm.shape <m = 64 , n = 8 , k = 32 >,
465
465
D [<s32 >, #nvvm.wgmma_scale_out <one >],
466
466
A [<u8 >, #nvvm.wgmma_scale_in <one >, <row >],
467
- B [<u8 >, #nvvm.wgmma_scale_in <one >, <row >]
467
+ B [<u8 >, #nvvm.wgmma_scale_in <one >, <col >]
468
468
: !mat16i32 -> !mat16i32
469
469
%result3 = nvvm.wgmma.mma_async %descA , %descB , %result2 ,
470
470
#nvvm.shape <m = 64 , n = 8 , k = 32 >,
471
471
D [<s32 >, #nvvm.wgmma_scale_out <one >],
472
472
A [<u8 >, #nvvm.wgmma_scale_in <one >, <row >],
473
- B [<u8 >, #nvvm.wgmma_scale_in <one >, <row >]
473
+ B [<u8 >, #nvvm.wgmma_scale_in <one >, <col >]
474
474
: !mat16i32 -> !mat16i32
475
475
return %result3 : !mat16i32
476
476
}
@@ -500,13 +500,13 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
500
500
#nvvm.shape <m = 64 , n = 64 , k = 8 >,
501
501
D [#nvvm.wgmma_type <f32 >, #nvvm.wgmma_scale_out <one >],
502
502
A [#nvvm.wgmma_type <tf32 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >],
503
- B [#nvvm.wgmma_type <tf32 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >]
503
+ B [#nvvm.wgmma_type <tf32 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <col >]
504
504
: !mat32f32 -> !mat32f32
505
505
%result2 = nvvm.wgmma.mma_async %descA , %descB , %result1 ,
506
506
#nvvm.shape <m = 64 , n = 64 , k = 8 >,
507
507
D [#nvvm.wgmma_type <f32 >, #nvvm.wgmma_scale_out <one >],
508
508
A [#nvvm.wgmma_type <tf32 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >],
509
- B [#nvvm.wgmma_type <tf32 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >]
509
+ B [#nvvm.wgmma_type <tf32 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <col >]
510
510
: !mat32f32 -> !mat32f32
511
511
return %result2 : !mat32f32
512
512
}
@@ -533,13 +533,13 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
533
533
#nvvm.shape <m = 64 , n = 64 , k = 32 >,
534
534
D [#nvvm.wgmma_type <f32 >, #nvvm.wgmma_scale_out <one >],
535
535
A [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >],
536
- B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >]
536
+ B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <col >]
537
537
: !mat32f32 -> !mat32f32
538
538
%result2 = nvvm.wgmma.mma_async %descA , %descB , %result1 ,
539
539
#nvvm.shape <m = 64 , n = 64 , k = 32 >,
540
540
D [#nvvm.wgmma_type <f32 >, #nvvm.wgmma_scale_out <one >],
541
541
A [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >],
542
- B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >]
542
+ B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <col >]
543
543
: !mat32f32 -> !mat32f32
544
544
return %result2 : !mat32f32
545
545
}
@@ -565,13 +565,13 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
565
565
#nvvm.shape <m = 64 , n = 64 , k = 32 >,
566
566
D [#nvvm.wgmma_type <f32 >, #nvvm.wgmma_scale_out <one >],
567
567
A [#nvvm.wgmma_type <e5m2 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >],
568
- B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >]
568
+ B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <col >]
569
569
: !mat32f32 -> !mat32f32
570
570
%result2 = nvvm.wgmma.mma_async %descA , %descB , %result1 ,
571
571
#nvvm.shape <m = 64 , n = 64 , k = 32 >,
572
572
D [#nvvm.wgmma_type <f32 >, #nvvm.wgmma_scale_out <one >],
573
573
A [#nvvm.wgmma_type <e5m2 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >],
574
- B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <row >]
574
+ B [#nvvm.wgmma_type <e4m3 >, #nvvm.wgmma_scale_in <one >, #nvvm.mma_layout <col >]
575
575
: !mat32f32 -> !mat32f32
576
576
return %result2 : !mat32f32
577
577
}
0 commit comments