Commit 948862b

[mlir][nvvm] Fix the verifier of wgmma.mma_async wrt transposed layouts (#97538)
WGMMA expects the layouts for A/B to be row-major/col-major; the transposed versions are col-major/row-major. When checking that other data types cannot use a transposed layout, the verifier should therefore reject col-major for A and row-major for B.
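For illustration, a minimal sketch of the rule the updated check enforces, borrowing the op form from the test file changed below; the %descA/%descB descriptors, the %acc accumulator, and the !mat16i32 result alias are assumed to be defined as in that test and are illustrative only:

  // Accepted: for non-f16/bf16 element types (here s8), A must be row-major
  // and B column-major, the canonical WGMMA layouts.
  %ok = nvvm.wgmma.mma_async %descA, %descB, %acc,
    #nvvm.shape<m = 64, n = 8, k = 32>,
    D [<s32>, #nvvm.wgmma_scale_out<one>],
    A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
    B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
    : !mat16i32 -> !mat16i32

  // Rejected by the fixed verifier: row-major B (like col-major A) is a
  // transposed layout, which is only legal for f16/bf16 operands.
  %bad = nvvm.wgmma.mma_async %descA, %descB, %acc,
    #nvvm.shape<m = 64, n = 8, k = 32>,
    D [<s32>, #nvvm.wgmma_scale_out<one>],
    A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
    B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
    : !mat16i32 -> !mat16i32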
1 parent 86187ed commit 948862b

2 files changed: +16 −13 lines
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 4 additions & 1 deletion
@@ -878,9 +878,12 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
   }
 
   // Check transpose (only available for f16/bf16)
+  // Matrices A should be stored in row-major and B in column-major.
+  // Only f16/bf16 matrices can be stored in either column-major or row-major
+  // by setting the transpose values (imm-trans-a, imm-trans-b) in PTX code.
   if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
       (getLayoutA() == mlir::NVVM::MMALayout::col ||
-       getLayoutB() == mlir::NVVM::MMALayout::col)) {
+       getLayoutB() == mlir::NVVM::MMALayout::row)) {
     return emitOpError()
            << "given layouts layout_a = " << stringifyMMALayout(getLayoutA())
            << " and layout_b = " << stringifyMMALayout(getLayoutB())

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 12 additions & 12 deletions
@@ -397,19 +397,19 @@ func.func @wgmma_s32_s8_s8_satfinite(%descA : i64, %descB : i64) -> !mat16i32{
     #nvvm.shape<m = 64, n = 8, k = 32>,
     D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
     A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
-    B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+    B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
     : !mat16i32 -> !mat16i32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
     #nvvm.shape<m = 64, n = 8, k = 32>,
     D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
     A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
-    B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+    B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
     : !mat16i32 -> !mat16i32
   %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
     #nvvm.shape<m = 64, n = 8, k = 32>,
     D [<s32>, #nvvm.wgmma_scale_out<one>, <satfinite>],
     A [<s8>, #nvvm.wgmma_scale_in<one>, <row>],
-    B [<s8>, #nvvm.wgmma_scale_in<one>, <row>]
+    B [<s8>, #nvvm.wgmma_scale_in<one>, <col>]
     : !mat16i32 -> !mat16i32
   return %result3 : !mat16i32
 }
@@ -458,19 +458,19 @@ func.func @wgmma_s32_u8_u8(%descA : i64, %descB : i64) -> !mat16i32 {
     #nvvm.shape<m = 64, n = 8, k = 32>,
     D [<s32>, #nvvm.wgmma_scale_out<one>],
     A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
-    B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+    B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
     : !mat16i32 -> !mat16i32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
     #nvvm.shape<m = 64, n = 8, k = 32>,
     D [<s32>, #nvvm.wgmma_scale_out<one>],
     A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
-    B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+    B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
     : !mat16i32 -> !mat16i32
   %result3 = nvvm.wgmma.mma_async %descA, %descB, %result2,
     #nvvm.shape<m = 64, n = 8, k = 32>,
     D [<s32>, #nvvm.wgmma_scale_out<one>],
     A [<u8>, #nvvm.wgmma_scale_in<one>, <row>],
-    B [<u8>, #nvvm.wgmma_scale_in<one>, <row>]
+    B [<u8>, #nvvm.wgmma_scale_in<one>, <col>]
     : !mat16i32 -> !mat16i32
   return %result3 : !mat16i32
 }
@@ -500,13 +500,13 @@ func.func @wgmma_f32_tf32_tf32(%descA : i64, %descB : i64) -> !mat32f32 {
     #nvvm.shape<m = 64, n = 64, k = 8>,
     D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
     A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
-    B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+    B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
     : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
     #nvvm.shape<m = 64, n = 64, k = 8>,
     D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
     A [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
-    B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+    B [#nvvm.wgmma_type<tf32>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
     : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }
@@ -533,13 +533,13 @@ func.func @wgmma_f32_e4m3_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
     #nvvm.shape<m = 64, n = 64, k = 32>,
     D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
     A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
-    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
     : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
     #nvvm.shape<m = 64, n = 64, k = 32>,
     D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
     A [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
-    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
     : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }
@@ -565,13 +565,13 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
     #nvvm.shape<m = 64, n = 64, k = 32>,
     D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
     A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
-    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
     : !mat32f32 -> !mat32f32
   %result2 = nvvm.wgmma.mma_async %descA, %descB, %result1,
     #nvvm.shape<m = 64, n = 64, k = 32>,
     D [#nvvm.wgmma_type<f32>, #nvvm.wgmma_scale_out<one>],
     A [#nvvm.wgmma_type<e5m2>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>],
-    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<row>]
+    B [#nvvm.wgmma_type<e4m3>, #nvvm.wgmma_scale_in<one>, #nvvm.mma_layout<col>]
     : !mat32f32 -> !mat32f32
   return %result2 : !mat32f32
 }
