Skip to content

Commit 12c241b

Browse files
authored
[MLIR][NVVM] Explicit Data Type for Output in wgmma.mma_async (#78713)
The current implementation of `nvvm.wgmma.mma_async` Op deduces the data type of the output matrix from the data type of struct member, which can be non-intuitive, especially in cases where types like `2xf16` are packed into `i32`. This PR addresses this issue by improving the Op to include an explicit data type for the output matrix. The modified Op now includes an explicit data type for Matrix-D (<f16>), and looks as follows: ``` %result = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, ... nvvm.wgmma.mma_async %descA, %descB, %result, #nvvm.shape<m = 64, n = 32, k = 16>, D [<f16>, #nvvm.wgmma_scale_out<zero>], A [<f16>, #nvvm.wgmma_scale_in<neg>, <col>], B [<f16>, #nvvm.wgmma_scale_in<neg>, <col>] ```
1 parent 21830c9 commit 12c241b

File tree

7 files changed

+135
-126
lines changed

7 files changed

+135
-126
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1833,11 +1833,14 @@ def WGMMATypeB1 : I32EnumAttrCase<"b1", 4>;
18331833
def WGMMATypeBF16 : I32EnumAttrCase<"bf16", 5>;
18341834
def WGMMATypeF8E4M3 : I32EnumAttrCase<"e4m3", 6>;
18351835
def WGMMATypeF8E5M2 : I32EnumAttrCase<"e5m2", 7>;
1836+
def WGMMATypeF32 : I32EnumAttrCase<"f32", 8>;
1837+
def WGMMATypeS32 : I32EnumAttrCase<"s32", 9>;
1838+
18361839
def WGMMATypes : I32EnumAttr<"WGMMATypes", "NVVM WGMMA types",
18371840
[WGMMATypeF16, WGMMATypeTF32,
18381841
WGMMATypeU8, WGMMATypeS8,
18391842
WGMMATypeB1, WGMMATypeBF16, WGMMATypeF8E4M3,
1840-
WGMMATypeF8E5M2]> {
1843+
WGMMATypeF8E5M2, WGMMATypeF32, WGMMATypeS32]> {
18411844
let genSpecializedAttr = 0;
18421845
let cppNamespace = "::mlir::NVVM";
18431846
}
@@ -1859,6 +1862,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
18591862
NVVM_MMAShapeAttr:$shape,
18601863
WGMMATypesAttr:$typeA,
18611864
WGMMATypesAttr:$typeB,
1865+
WGMMATypesAttr:$typeD,
18621866
WGMMAScaleOutAttr:$scaleD,
18631867
WGMMAScaleInAttr:$scaleA,
18641868
WGMMAScaleInAttr:$scaleB,
@@ -1868,8 +1872,8 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
18681872
);
18691873

18701874
let assemblyFormat = [{
1871-
$descriptorA `,` $descriptorB `,` $shape `,`
1872-
`D` `[` $inouts `,` $scaleD (`,` $satfinite^)? `]` `,`
1875+
$descriptorA `,` $descriptorB `,` $inouts `,` $shape `,`
1876+
`D` `[` $typeD `,` $scaleD (`,` $satfinite^)? `]` `,`
18731877
`A` `[` $typeA `,` $scaleA `,` $layoutA `]` `,`
18741878
`B` `[` $typeB `,` $scaleB `,` $layoutB `]`
18751879
attr-dict `:`

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,10 +1267,11 @@ struct NVGPUWarpgroupMmaOpLowering
12671267
}
12681268

12691269
/// Generates WGMMATypesAttr from MLIR Type
1270-
NVVM::WGMMATypesAttr generateWgmmaType(Type type) const {
1271-
auto getWgmmaType = [](Type elemType) {
1270+
NVVM::WGMMATypesAttr generateWgmmaType(Type type,
1271+
bool useF32 = false) const {
1272+
auto getWgmmaType = [=](Type elemType) {
12721273
if (elemType.isF32() || elemType.isTF32())
1273-
return NVVM::WGMMATypes::tf32;
1274+
return useF32 ? NVVM::WGMMATypes::f32 : NVVM::WGMMATypes::tf32;
12741275
if (elemType.isF16())
12751276
return NVVM::WGMMATypes::f16;
12761277
if (elemType.isBF16())
@@ -1285,6 +1286,8 @@ struct NVGPUWarpgroupMmaOpLowering
12851286
return NVVM::WGMMATypes::s8;
12861287
if (elemType.isUnsignedInteger(8))
12871288
return NVVM::WGMMATypes::u8;
1289+
if (elemType.isInteger(32))
1290+
return NVVM::WGMMATypes::s32;
12881291
llvm_unreachable("unsupported type");
12891292
};
12901293
return NVVM::WGMMATypesAttr::get(op->getContext(), getWgmmaType(type));
@@ -1397,6 +1400,9 @@ struct NVGPUWarpgroupMmaOpLowering
13971400
Type elemB = op.getDescriptorB().getType().getTensor().getElementType();
13981401
NVVM::WGMMATypesAttr itypeB = generateWgmmaType(elemB);
13991402

1403+
Type elemD = op.getMatrixC().getType().getFragmented().getElementType();
1404+
NVVM::WGMMATypesAttr itypeD = generateWgmmaType(elemD, true);
1405+
14001406
NVVM::MMAShapeAttr shape = generateWgmmaShape();
14011407
NVVM::WGMMAScaleOutAttr scaleOut = generateScaleOut();
14021408
NVVM::WGMMAScaleInAttr scaleIn = generateScaleIn();
@@ -1408,7 +1414,8 @@ struct NVGPUWarpgroupMmaOpLowering
14081414

14091415
return b.create<NVVM::WgmmaMmaAsyncOp>(
14101416
matrixC.getType(), matrixC, descriptorA, descriptorB, shape, itypeA,
1411-
itypeB, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
1417+
itypeB, itypeD, scaleOut, scaleIn, scaleIn, layoutA, layoutB,
1418+
overflow);
14121419
}
14131420

14141421
/// Generates multiple wgmma instructions to complete the given GEMM shape

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 57 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -755,37 +755,44 @@ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
755755
return failure();
756756
}
757757

758-
LogicalResult isAllowedWGMMADataType(Type typeD, NVVM::WGMMATypes typeA,
758+
LogicalResult isAllowedWGMMADataType(NVVM::WGMMATypes typeD,
759+
NVVM::WGMMATypes typeA,
759760
NVVM::WGMMATypes typeB) {
760761
switch (typeA) {
761762
case NVVM::WGMMATypes::f16:
762-
if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::f16)
763+
if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
764+
typeB == NVVM::WGMMATypes::f16)
763765
return success();
764766
break;
765767
case NVVM::WGMMATypes::tf32:
766-
if (typeD.isF32() && typeB == NVVM::WGMMATypes::tf32)
768+
if (typeD == NVVM::WGMMATypes::f32 && typeB == NVVM::WGMMATypes::tf32)
767769
return success();
768770
break;
769771
case NVVM::WGMMATypes::u8:
770772
case NVVM::WGMMATypes::s8:
771-
if (typeD.isInteger(32) &&
773+
if (typeD == NVVM::WGMMATypes::s32 &&
772774
(typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
773775
return success();
774776
break;
775777
case NVVM::WGMMATypes::b1:
776-
if (typeD.isInteger(32) && typeB == NVVM::WGMMATypes::b1)
778+
if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
777779
return success();
778780
break;
779781
case NVVM::WGMMATypes::bf16:
780-
if ((typeD.isF32() || typeD.isF16()) && typeB == NVVM::WGMMATypes::bf16)
782+
if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
783+
typeB == NVVM::WGMMATypes::bf16)
781784
return success();
782785
break;
783786
case NVVM::WGMMATypes::e4m3:
784787
case NVVM::WGMMATypes::e5m2:
785-
if ((typeD.isF32() || typeD.isF16()) &&
788+
if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16) &&
786789
(typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
787790
return success();
788791
break;
792+
case WGMMATypes::f32:
793+
case WGMMATypes::s32:
794+
llvm_unreachable("unsupported input types");
795+
break;
789796
}
790797
return failure();
791798
}
@@ -799,19 +806,24 @@ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
799806
80, 96, 112, 128, 144, 160,
800807
176, 192, 208, 224, 240, 256};
801808
switch (typeA) {
802-
case mlir::NVVM::WGMMATypes::f16:
803-
case mlir::NVVM::WGMMATypes::tf32:
804-
case mlir::NVVM::WGMMATypes::bf16:
805-
case mlir::NVVM::WGMMATypes::e4m3:
806-
case mlir::NVVM::WGMMATypes::e5m2:
809+
case WGMMATypes::f16:
810+
case WGMMATypes::tf32:
811+
case WGMMATypes::bf16:
812+
case WGMMATypes::e4m3:
813+
case WGMMATypes::e5m2:
807814
if (llvm::is_contained(allowedN, sizeN))
808815
return success();
809816
break;
810-
case mlir::NVVM::WGMMATypes::u8:
811-
case mlir::NVVM::WGMMATypes::s8:
812-
case mlir::NVVM::WGMMATypes::b1:
817+
case WGMMATypes::u8:
818+
case WGMMATypes::s8:
819+
case WGMMATypes::b1:
813820
if (llvm::is_contained(allowedNshort, sizeN))
814821
return success();
822+
break;
823+
case WGMMATypes::f32:
824+
case WGMMATypes::s32:
825+
llvm_unreachable("unsupported input types");
826+
break;
815827
}
816828
return failure();
817829
}
@@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
821833
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
822834
if (!stype)
823835
return emitOpError() << "expected results to be struct";
824-
Type outputType = stype.getBody().front();
825836
int outputSize = stype.getBody().size();
837+
WGMMATypes typeD = getTypeD();
838+
WGMMATypes typeA = getTypeA();
839+
WGMMATypes typeB = getTypeB();
840+
826841
for (Type t : stype.getBody()) {
827-
if (t != outputType)
842+
if (t != stype.getBody().front())
828843
return emitOpError()
829844
<< "all elements in struct must be same type but there is " << t;
830845
}
831846

832-
if (!outputType.isF32() && !outputType.isInteger(32) && !outputType.isF16()) {
847+
if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
848+
typeD != WGMMATypes::s32) {
833849
return emitOpError() << "does not support the given output type "
834-
<< outputType;
850+
<< NVVM::stringifyWGMMATypes(typeD);
835851
}
836-
if (outputType.isInteger(32) && (getScaleA() == NVVM::WGMMAScaleIn::neg ||
837-
getScaleB() == NVVM::WGMMAScaleIn::neg)) {
852+
if (typeD == WGMMATypes::s32 &&
853+
(getScaleA() == WGMMAScaleIn::neg || getScaleB() == WGMMAScaleIn::neg)) {
838854
return emitOpError() << "has s32 output, scaleA and scaleB cannot be neg";
839855
}
840856

841-
mlir::NVVM::WGMMATypes typeA = getTypeA();
842-
mlir::NVVM::WGMMATypes typeB = getTypeB();
843-
if (failed(isAllowedWGMMADataType(outputType, typeA, typeB))) {
844-
return emitOpError() << outputType
857+
if (failed(isAllowedWGMMADataType(typeD, typeA, typeB))) {
858+
return emitOpError() << NVVM::stringifyWGMMATypes(typeD)
845859
<< " += " << NVVM::stringifyWGMMATypes(typeA) << " * "
846860
<< NVVM::stringifyWGMMATypes(typeB)
847861
<< ", it is not supported.";
@@ -866,8 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
866880
}
867881

868882
// Check transpose (only available for f16/bf16)
869-
if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
870-
typeA != mlir::NVVM::WGMMATypes::bf16) &&
883+
if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16) &&
871884
(getLayoutA() == mlir::NVVM::MMALayout::col ||
872885
getLayoutB() == mlir::NVVM::MMALayout::col)) {
873886
return emitOpError()
@@ -876,29 +889,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
876889
<< " for input types " << stringifyWGMMATypes(typeA) << " and "
877890
<< stringifyWGMMATypes(typeB)
878891
<< " requires transpose. However, this is only supported for: "
879-
<< stringifyMMATypes(mlir::NVVM::MMATypes::f16) << " and "
880-
<< stringifyMMATypes(mlir::NVVM::MMATypes::bf16);
892+
<< stringifyMMATypes(MMATypes::f16) << " and "
893+
<< stringifyMMATypes(MMATypes::bf16);
881894
}
882895

883896
// Check result registers
884-
int expectedOutput;
885-
if (outputType.isF32() || outputType.isInteger(32))
897+
int expectedOutput = 0;
898+
if (typeD == WGMMATypes::f32 || typeD == WGMMATypes::s32)
886899
expectedOutput = getShape().getN() / 2;
887-
if (outputType.isF16())
900+
if (typeD == WGMMATypes::f16)
888901
expectedOutput = getShape().getN() / 4;
889902
if (outputSize != expectedOutput) {
890903
return emitOpError() << "results " << expectedOutput
891904
<< ", however output struct has " << outputSize
892905
<< " elements";
893906
}
894-
// Check satfinite (only availalbe for s32 accumulator)
895-
if (!outputType.isInteger(32) &&
907+
// Check satfinite (only available for s32 accumulator)
908+
if (typeD != WGMMATypes::s32 &&
896909
getSatfinite().value_or(NVVM::MMAIntOverflow::wrapped) ==
897910
NVVM::MMAIntOverflow::satfinite) {
898911
return emitOpError()
899912
<< " `satfinite` can be only used with s32 accumulator, however "
900913
"the current accumulator is "
901-
<< outputType;
914+
<< NVVM::stringifyWGMMATypes(typeD);
902915
}
903916

904917
return success();
@@ -907,27 +920,15 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
907920
std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
908921

909922
int m = getShape().getM(), n = getShape().getN(), k = getShape().getK();
910-
bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
911-
getTypeA() == mlir::NVVM::WGMMATypes::bf16;
923+
bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
912924

913-
Value outValue = getResults() ? getResults() : getInouts();
914-
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
915-
Type outputType = stype.getBody().front();
916-
std::string outputTypeName;
917-
if (outputType.isF16())
918-
outputTypeName = "f16";
919-
else if (outputType.isF32())
920-
outputTypeName = "f32";
921-
else if (outputType.isInteger(32))
922-
outputTypeName = "s32";
923-
else
924-
assert(false && "unsupported output type");
925+
StringRef outputTypeName = stringifyWGMMATypes(getTypeD());
925926

926-
int expectedOutputRegisters;
927-
if (outputType.isF32() || outputType.isInteger(32))
928-
expectedOutputRegisters = getShape().getN() / 2;
929-
if (outputType.isF16())
927+
int expectedOutputRegisters = 0;
928+
if (getTypeD() == WGMMATypes::f16)
930929
expectedOutputRegisters = getShape().getN() / 4;
930+
else
931+
expectedOutputRegisters = getShape().getN() / 2;
931932

932933
std::string ptx;
933934
llvm::raw_string_ostream ss(ptx);
@@ -958,7 +959,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
958959
ss << " $" << (regCnt) << ","
959960
<< " $" << (regCnt + 1) << ","
960961
<< " p";
961-
if (!outputType.isInteger(32)) {
962+
if (getTypeD() != WGMMATypes::s32) {
962963
ss << ", $" << (regCnt + 3) << ", $" << (regCnt + 4);
963964
}
964965
// Don't add transpose parameters unless needed.
@@ -975,11 +976,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
975976
RewriterBase &rewriter,
976977
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
977978
&asmValues) {
978-
Value outValue = getResults() ? getResults() : getInouts();
979-
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType());
980-
Type outputType = stype.getBody().front();
981-
bool isF16 = getTypeA() == mlir::NVVM::WGMMATypes::f16 ||
982-
getTypeA() == mlir::NVVM::WGMMATypes::bf16;
979+
bool isF16 = getTypeA() == WGMMATypes::f16 || getTypeA() == WGMMATypes::bf16;
983980
if (getResults())
984981
asmValues.push_back({getResults(), mlir::NVVM::PTXRegisterMod::Write});
985982
if (getInouts())
@@ -988,7 +985,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
988985
asmValues.push_back({getDescriptorB(), mlir::NVVM::PTXRegisterMod::Read});
989986
asmValues.push_back({makeConstantI32(rewriter, static_cast<int>(getScaleD())),
990987
mlir::NVVM::PTXRegisterMod::Read});
991-
if (!outputType.isInteger(32)) {
988+
if (getTypeD() != WGMMATypes::s32) {
992989
asmValues.push_back(
993990
{makeConstantI32(rewriter,
994991
getScaleA() == NVVM::WGMMAScaleIn::neg ? -1 : 1),

0 commit comments

Comments
 (0)