@@ -755,37 +755,44 @@ FailureOr<int> getAllowedSizeK(NVVM::WGMMATypes typeA) {
755
755
return failure ();
756
756
}
757
757
758
- LogicalResult isAllowedWGMMADataType (Type typeD, NVVM::WGMMATypes typeA,
758
+ LogicalResult isAllowedWGMMADataType (NVVM::WGMMATypes typeD,
759
+ NVVM::WGMMATypes typeA,
759
760
NVVM::WGMMATypes typeB) {
760
761
switch (typeA) {
761
762
case NVVM::WGMMATypes::f16 :
762
- if ((typeD.isF32 () || typeD.isF16 ()) && typeB == NVVM::WGMMATypes::f16 )
763
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16 ) &&
764
+ typeB == NVVM::WGMMATypes::f16 )
763
765
return success ();
764
766
break ;
765
767
case NVVM::WGMMATypes::tf32:
766
- if (typeD. isF32 () && typeB == NVVM::WGMMATypes::tf32)
768
+ if (typeD == NVVM::WGMMATypes:: f32 && typeB == NVVM::WGMMATypes::tf32)
767
769
return success ();
768
770
break ;
769
771
case NVVM::WGMMATypes::u8 :
770
772
case NVVM::WGMMATypes::s8:
771
- if (typeD. isInteger ( 32 ) &&
773
+ if (typeD == NVVM::WGMMATypes::s32 &&
772
774
(typeB == NVVM::WGMMATypes::u8 || typeB == NVVM::WGMMATypes::s8))
773
775
return success ();
774
776
break ;
775
777
case NVVM::WGMMATypes::b1:
776
- if (typeD. isInteger ( 32 ) && typeB == NVVM::WGMMATypes::b1)
778
+ if (typeD == NVVM::WGMMATypes::s32 && typeB == NVVM::WGMMATypes::b1)
777
779
return success ();
778
780
break ;
779
781
case NVVM::WGMMATypes::bf16 :
780
- if ((typeD.isF32 () || typeD.isF16 ()) && typeB == NVVM::WGMMATypes::bf16 )
782
+ if ((typeD == NVVM::WGMMATypes::f32 || typeD == NVVM::WGMMATypes::f16 ) &&
783
+ typeB == NVVM::WGMMATypes::bf16 )
781
784
return success ();
782
785
break ;
783
786
case NVVM::WGMMATypes::e4m3:
784
787
case NVVM::WGMMATypes::e5m2:
785
- if ((typeD. isF32 () || typeD. isF16 () ) &&
788
+ if ((typeD == NVVM::WGMMATypes:: f32 || typeD == NVVM::WGMMATypes:: f16 ) &&
786
789
(typeB == NVVM::WGMMATypes::e5m2 || typeB == NVVM::WGMMATypes::e4m3))
787
790
return success ();
788
791
break ;
792
+ case WGMMATypes::f32 :
793
+ case WGMMATypes::s32:
794
+ llvm_unreachable (" unsupported input types" );
795
+ break ;
789
796
}
790
797
return failure ();
791
798
}
@@ -799,19 +806,24 @@ LogicalResult isAllowedSizeN(int sizeN, NVVM::WGMMATypes typeA) {
799
806
80 , 96 , 112 , 128 , 144 , 160 ,
800
807
176 , 192 , 208 , 224 , 240 , 256 };
801
808
switch (typeA) {
802
- case mlir::NVVM:: WGMMATypes::f16 :
803
- case mlir::NVVM:: WGMMATypes::tf32:
804
- case mlir::NVVM:: WGMMATypes::bf16 :
805
- case mlir::NVVM:: WGMMATypes::e4m3:
806
- case mlir::NVVM:: WGMMATypes::e5m2:
809
+ case WGMMATypes::f16 :
810
+ case WGMMATypes::tf32:
811
+ case WGMMATypes::bf16 :
812
+ case WGMMATypes::e4m3:
813
+ case WGMMATypes::e5m2:
807
814
if (llvm::is_contained (allowedN, sizeN))
808
815
return success ();
809
816
break ;
810
- case mlir::NVVM:: WGMMATypes::u8 :
811
- case mlir::NVVM:: WGMMATypes::s8:
812
- case mlir::NVVM:: WGMMATypes::b1:
817
+ case WGMMATypes::u8 :
818
+ case WGMMATypes::s8:
819
+ case WGMMATypes::b1:
813
820
if (llvm::is_contained (allowedNshort, sizeN))
814
821
return success ();
822
+ break ;
823
+ case WGMMATypes::f32 :
824
+ case WGMMATypes::s32:
825
+ llvm_unreachable (" unsupported input types" );
826
+ break ;
815
827
}
816
828
return failure ();
817
829
}
@@ -821,27 +833,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
821
833
auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType ());
822
834
if (!stype)
823
835
return emitOpError () << " expected results to be struct" ;
824
- Type outputType = stype.getBody ().front ();
825
836
int outputSize = stype.getBody ().size ();
837
+ WGMMATypes typeD = getTypeD ();
838
+ WGMMATypes typeA = getTypeA ();
839
+ WGMMATypes typeB = getTypeB ();
840
+
826
841
for (Type t : stype.getBody ()) {
827
- if (t != outputType )
842
+ if (t != stype. getBody (). front () )
828
843
return emitOpError ()
829
844
<< " all elements in struct must be same type but there is " << t;
830
845
}
831
846
832
- if (!outputType.isF32 () && !outputType.isInteger (32 ) && !outputType.isF16 ()) {
847
+ if (typeD != WGMMATypes::f32 && typeD != WGMMATypes::f16 &&
848
+ typeD != WGMMATypes::s32) {
833
849
return emitOpError () << " does not support the given output type "
834
- << outputType ;
850
+ << NVVM::stringifyWGMMATypes (typeD) ;
835
851
}
836
- if (outputType. isInteger ( 32 ) && ( getScaleA () == NVVM::WGMMAScaleIn::neg ||
837
- getScaleB () == NVVM:: WGMMAScaleIn::neg)) {
852
+ if (typeD == WGMMATypes::s32 &&
853
+ ( getScaleA () == WGMMAScaleIn::neg || getScaleB () == WGMMAScaleIn::neg)) {
838
854
return emitOpError () << " has s32 output, scaleA and scaleB cannot be neg" ;
839
855
}
840
856
841
- mlir::NVVM::WGMMATypes typeA = getTypeA ();
842
- mlir::NVVM::WGMMATypes typeB = getTypeB ();
843
- if (failed (isAllowedWGMMADataType (outputType, typeA, typeB))) {
844
- return emitOpError () << outputType
857
+ if (failed (isAllowedWGMMADataType (typeD, typeA, typeB))) {
858
+ return emitOpError () << NVVM::stringifyWGMMATypes (typeD)
845
859
<< " += " << NVVM::stringifyWGMMATypes (typeA) << " * "
846
860
<< NVVM::stringifyWGMMATypes (typeB)
847
861
<< " , it is not supported." ;
@@ -866,8 +880,7 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
866
880
}
867
881
868
882
// Check transpose (only available for f16/bf16)
869
- if ((typeA != mlir::NVVM::WGMMATypes::f16 &&
870
- typeA != mlir::NVVM::WGMMATypes::bf16 ) &&
883
+ if ((typeA != WGMMATypes::f16 && typeA != WGMMATypes::bf16 ) &&
871
884
(getLayoutA () == mlir::NVVM::MMALayout::col ||
872
885
getLayoutB () == mlir::NVVM::MMALayout::col)) {
873
886
return emitOpError ()
@@ -876,29 +889,29 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
876
889
<< " for input types " << stringifyWGMMATypes (typeA) << " and "
877
890
<< stringifyWGMMATypes (typeB)
878
891
<< " requires transpose. However, this is only supported for: "
879
- << stringifyMMATypes (mlir::NVVM:: MMATypes::f16 ) << " and "
880
- << stringifyMMATypes (mlir::NVVM:: MMATypes::bf16 );
892
+ << stringifyMMATypes (MMATypes::f16 ) << " and "
893
+ << stringifyMMATypes (MMATypes::bf16 );
881
894
}
882
895
883
896
// Check result registers
884
- int expectedOutput;
885
- if (outputType. isF32 () || outputType. isInteger ( 32 ) )
897
+ int expectedOutput = 0 ;
898
+ if (typeD == WGMMATypes:: f32 || typeD == WGMMATypes::s32 )
886
899
expectedOutput = getShape ().getN () / 2 ;
887
- if (outputType. isF16 () )
900
+ if (typeD == WGMMATypes:: f16 )
888
901
expectedOutput = getShape ().getN () / 4 ;
889
902
if (outputSize != expectedOutput) {
890
903
return emitOpError () << " results " << expectedOutput
891
904
<< " , however output struct has " << outputSize
892
905
<< " elements" ;
893
906
}
894
- // Check satfinite (only availalbe for s32 accumulator)
895
- if (!outputType. isInteger ( 32 ) &&
907
+ // Check satfinite (only available for s32 accumulator)
908
+ if (typeD != WGMMATypes::s32 &&
896
909
getSatfinite ().value_or (NVVM::MMAIntOverflow::wrapped) ==
897
910
NVVM::MMAIntOverflow::satfinite) {
898
911
return emitOpError ()
899
912
<< " `satfinite` can be only used with s32 accumulator, however "
900
913
" the current accumulator is "
901
- << outputType ;
914
+ << NVVM::stringifyWGMMATypes (typeD) ;
902
915
}
903
916
904
917
return success ();
@@ -907,27 +920,15 @@ LogicalResult NVVM::WgmmaMmaAsyncOp::verify() {
907
920
std::string NVVM::WgmmaMmaAsyncOp::getPtx () {
908
921
909
922
int m = getShape ().getM (), n = getShape ().getN (), k = getShape ().getK ();
910
- bool isF16 = getTypeA () == mlir::NVVM::WGMMATypes::f16 ||
911
- getTypeA () == mlir::NVVM::WGMMATypes::bf16 ;
923
+ bool isF16 = getTypeA () == WGMMATypes::f16 || getTypeA () == WGMMATypes::bf16 ;
912
924
913
- Value outValue = getResults () ? getResults () : getInouts ();
914
- auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType ());
915
- Type outputType = stype.getBody ().front ();
916
- std::string outputTypeName;
917
- if (outputType.isF16 ())
918
- outputTypeName = " f16" ;
919
- else if (outputType.isF32 ())
920
- outputTypeName = " f32" ;
921
- else if (outputType.isInteger (32 ))
922
- outputTypeName = " s32" ;
923
- else
924
- assert (false && " unsupported output type" );
925
+ StringRef outputTypeName = stringifyWGMMATypes (getTypeD ());
925
926
926
- int expectedOutputRegisters;
927
- if (outputType.isF32 () || outputType.isInteger (32 ))
928
- expectedOutputRegisters = getShape ().getN () / 2 ;
929
- if (outputType.isF16 ())
927
+ int expectedOutputRegisters = 0 ;
928
+ if (getTypeD () == WGMMATypes::f16 )
930
929
expectedOutputRegisters = getShape ().getN () / 4 ;
930
+ else
931
+ expectedOutputRegisters = getShape ().getN () / 2 ;
931
932
932
933
std::string ptx;
933
934
llvm::raw_string_ostream ss (ptx);
@@ -958,7 +959,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
958
959
ss << " $" << (regCnt) << " ,"
959
960
<< " $" << (regCnt + 1 ) << " ,"
960
961
<< " p" ;
961
- if (!outputType. isInteger ( 32 ) ) {
962
+ if (getTypeD () != WGMMATypes::s32 ) {
962
963
ss << " , $" << (regCnt + 3 ) << " , $" << (regCnt + 4 );
963
964
}
964
965
// Don't add transpose parameters unless needed.
@@ -975,11 +976,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
975
976
RewriterBase &rewriter,
976
977
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
977
978
&asmValues) {
978
- Value outValue = getResults () ? getResults () : getInouts ();
979
- auto stype = dyn_cast<LLVM::LLVMStructType>(outValue.getType ());
980
- Type outputType = stype.getBody ().front ();
981
- bool isF16 = getTypeA () == mlir::NVVM::WGMMATypes::f16 ||
982
- getTypeA () == mlir::NVVM::WGMMATypes::bf16 ;
979
+ bool isF16 = getTypeA () == WGMMATypes::f16 || getTypeA () == WGMMATypes::bf16 ;
983
980
if (getResults ())
984
981
asmValues.push_back ({getResults (), mlir::NVVM::PTXRegisterMod::Write});
985
982
if (getInouts ())
@@ -988,7 +985,7 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
988
985
asmValues.push_back ({getDescriptorB (), mlir::NVVM::PTXRegisterMod::Read});
989
986
asmValues.push_back ({makeConstantI32 (rewriter, static_cast <int >(getScaleD ())),
990
987
mlir::NVVM::PTXRegisterMod::Read});
991
- if (!outputType. isInteger ( 32 ) ) {
988
+ if (getTypeD () != WGMMATypes::s32 ) {
992
989
asmValues.push_back (
993
990
{makeConstantI32 (rewriter,
994
991
getScaleA () == NVVM::WGMMAScaleIn::neg ? -1 : 1 ),
0 commit comments