Skip to content

Commit 105ce58

Browse files
[mlir][amdgpu] Define an amdgpu.scaling_mfma wrapper (#137498)
Create a wrapper around the new scaled MFMAs that operate on specific element types and tile sizes. See [Issue](iree-org/iree#20616). --------- Signed-off-by: Muzammiluddin Syed <[email protected]>
1 parent 692f832 commit 105ce58

File tree

4 files changed

+189
-4
lines changed

4 files changed

+189
-4
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,10 @@ def MFMAOutTypes : AnyTypeOf<[F64,
687687
VectorOfLengthAndType<[4, 16, 32], [F32]>,
688688
VectorOfLengthAndType<[4, 16, 32], [I32]>,
689689
VectorOfLengthAndType<[4], [F64]>]>;
690+
// scaled_mfma
691+
def ScaledMFMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[32], [F8E5M2, F8E4M3FN]>,
692+
VectorOfLengthAndType<[32], [F6E2M3FN, F6E3M2FN, F4E2M1FN]>]>;
693+
def ScaledMFMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 16], [F32]>]>;
690694
// wmma
691695
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<
692696
[4, 8, 16],
@@ -804,7 +808,7 @@ def AMDGPU_GatherToLDSOp :
804808
TypeAttr:$transferType
805809
)>,
806810
Results<(outs)> {
807-
let summary = "MLIR wrapper for CDNA mfma instructions";
811+
let summary = "MLIR wrapper for CDNA Gather to LDS instructions";
808812
let description = [{
809813
The `amdgpu.global_load` op is a wrapper around the `global_load_lds` instructions.
810814

@@ -830,4 +834,52 @@ def AMDGPU_GatherToLDSOp :
830834
let hasVerifier = 1;
831835
}
832836

837+
def AMDGPU_ScaledMFMAOp :
838+
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
839+
Pure]>,
840+
Arguments<(ins
841+
I32Attr:$m,
842+
I32Attr:$n,
843+
I32Attr:$k,
844+
ScaledMFMAInTypes:$sourceA,
845+
ScaledMFMAInTypes:$sourceB,
846+
ScaledMFMAOutTypes:$destC,
847+
AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesA,
848+
AnyTypeOf<[F8E8M0FNU, FixedVectorOfLengthAndType<[4], [F8E8M0FNU]>]>:$scalesB,
849+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxA,
850+
ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$scalesIdxB
851+
)>,
852+
Results<(outs ScaledMFMAOutTypes: $destD)> {
853+
let summary = "MLIR wrapper for CDNA scaled mfma instructions";
854+
let description = [{
855+
The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
856+
for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
857+
multiple outer products in order to allow fast matrix multiplication.
858+
859+
The wrapper will select an appropriate `mfma` instruction, if one is available,
860+
based on the provided `m`, `n`, and `k` attributes, along with the
861+
types of the source and destination arguments.
862+
863+
Note, this wrapper allows specifying `vector<4Kxi8>` arguments to MFMA
864+
intrinsics that take an integer type of width `4K`. For example,
865+
one can provide a `vector<4xi8>` as an argument to an MFMA instruction that
866+
logically takes 4 i8s but whose intrinsics are specified to take an i32.
867+
In these cases, the bytes in the vector will be concatenated in little-endian
868+
order (that is, v[0] will go to arg[7:0], v[1] to arg[15:8] and so on).
869+
870+
This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
871+
- `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
872+
fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
873+
size.
874+
- `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
875+
are omitted from this wrapper.
876+
- The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
877+
double-precision operations on gfx94x and so are not included here.
878+
}];
879+
let assemblyFormat = [{
880+
`(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
881+
attr-dict
882+
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
883+
}];
884+
}
833885
#endif // AMDGPU

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include "llvm/ADT/STLExtras.h"
2525
#include "llvm/ADT/TypeSwitch.h"
26+
#include "llvm/Support/Casting.h"
2627
#include <optional>
2728

2829
namespace mlir {
@@ -528,6 +529,25 @@ static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
528529
return input;
529530
}
530531

532+
/// Converts the scaled MFMA operands, `scalesA` and `scalesB`, from MLIR AMDGPU
533+
/// dialect convention to ROCDL and LLVM AMDGPU intrinsics convention.
534+
///
535+
/// Specifically:
536+
/// 1. If `input` is an i8 value, zero-extend it to i32
537+
/// 2. If `input` is a vector of length 4 and type i8, cast it to i32
538+
///
539+
/// Note that the type of `input` has already been LLVM type converted:
540+
/// therefore 8-bit and smaller floats are represented as their corresponding
541+
/// `iN` integers.
542+
static Value castMFMAScaleOperand(ConversionPatternRewriter &rewriter,
543+
Location loc, Value input) {
544+
Type inputType = input.getType();
545+
Type outputType = rewriter.getI32Type();
546+
if (auto intType = dyn_cast<IntegerType>(inputType))
547+
return rewriter.create<LLVM::ZExtOp>(loc, outputType, input);
548+
return rewriter.create<LLVM::BitcastOp>(loc, outputType, input);
549+
}
550+
531551
/// Push an input operand. If it is a float type, nothing to do. If it is
532552
/// an integer type, then we need to also push its signedness (1 for signed, 0
533553
/// for unsigned) and we need to pack the input 16xi8 vector into a 4xi32
@@ -833,6 +853,14 @@ mfmaOpToScaledIntrinsic(MFMAOp mfma, Chipset chipset) {
833853
mfma.getBlocks(), chipset);
834854
}
835855

856+
static std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
857+
mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
858+
return mfmaOpToScaledIntrinsic(smfma.getSourceA().getType(),
859+
smfma.getSourceB().getType(),
860+
smfma.getDestC().getType(), smfma.getM(),
861+
smfma.getN(), smfma.getK(), 1u, chipset);
862+
}
863+
836864
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
837865
/// if one exists. This includes checking to ensure the intrinsic is supported
838866
/// on the architecture you are compiling for.
@@ -954,6 +982,52 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
954982
}
955983
};
956984

985+
struct ScaledMFMAOpLowering : public ConvertOpToLLVMPattern<ScaledMFMAOp> {
986+
ScaledMFMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
987+
: ConvertOpToLLVMPattern(converter), chipset(chipset) {}
988+
989+
Chipset chipset;
990+
991+
LogicalResult
992+
matchAndRewrite(ScaledMFMAOp op, ScaledMFMAOpAdaptor adaptor,
993+
ConversionPatternRewriter &rewriter) const override {
994+
Location loc = op.getLoc();
995+
Type intrinsicOutType = typeConverter->convertType(op.getDestD().getType());
996+
997+
if (chipset.majorVersion != 9 || chipset < kGfx950)
998+
return op->emitOpError("scaled MFMA only supported on gfx908+");
999+
std::optional<std::tuple<StringRef, uint32_t, uint32_t>>
1000+
maybeScaledIntrinsic = mfmaOpToScaledIntrinsic(op, chipset);
1001+
if (!maybeScaledIntrinsic.has_value())
1002+
return op.emitOpError(
1003+
"no intrinsic matching scaled MFMA size on given chipset");
1004+
1005+
auto [intrinsicName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;
1006+
OperationState loweredOp(loc, intrinsicName);
1007+
loweredOp.addTypes(intrinsicOutType);
1008+
loweredOp.addOperands(
1009+
{convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
1010+
convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
1011+
adaptor.getDestC()});
1012+
Value scalesIdxA =
1013+
createI32Constant(rewriter, loc, adaptor.getScalesIdxA());
1014+
Value scalesIdxB =
1015+
createI32Constant(rewriter, loc, adaptor.getScalesIdxB());
1016+
loweredOp.addOperands(
1017+
{createI32Constant(rewriter, loc, aTypeCode),
1018+
createI32Constant(rewriter, loc, bTypeCode),
1019+
/*scales idx A=*/scalesIdxA,
1020+
/*scales A*/
1021+
castMFMAScaleOperand(rewriter, loc, adaptor.getScalesA()),
1022+
/*scales idx B=*/scalesIdxB,
1023+
/*scales B*/
1024+
castMFMAScaleOperand(rewriter, loc, adaptor.getScalesB())});
1025+
Value lowered = rewriter.create(loweredOp)->getResult(0);
1026+
rewriter.replaceOp(op, lowered);
1027+
return success();
1028+
}
1029+
};
1030+
9571031
struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
9581032
WMMAOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
9591033
: ConvertOpToLLVMPattern<WMMAOp>(converter), chipset(chipset) {}
@@ -1474,8 +1548,9 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
14741548
RawBufferOpLowering<RawBufferAtomicCmpswapOp,
14751549
ROCDL::RawPtrBufferAtomicCmpSwap>,
14761550
AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
1477-
MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
1478-
PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
1479-
GatherToLDSOpLowering>(converter, chipset);
1551+
MFMAOpLowering, ScaledMFMAOpLowering, WMMAOpLowering,
1552+
ExtPackedFp8OpLowering, PackedTrunc2xFp8OpLowering,
1553+
PackedStochRoundFp8OpLowering, GatherToLDSOpLowering>(converter,
1554+
chipset);
14801555
patterns.add<AMDGPUSwizzleBitModeLowering>(converter);
14811556
}

mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,54 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5151

5252
func.return
5353
}
54+
55+
// CHECK-LABEL: func @scaled_mfma_to_rocdl(
56+
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
57+
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
58+
%arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
59+
%arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
60+
%arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
61+
%arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
62+
63+
// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
64+
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
65+
// CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
66+
// CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
67+
68+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
69+
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
70+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
71+
amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
72+
73+
// CHECK: llvm.bitcast
74+
75+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
76+
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
77+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
78+
amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
79+
80+
// CHECK: llvm.bitcast
81+
82+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
83+
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
84+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
85+
amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
86+
87+
// CHECK: llvm.bitcast
88+
// CHECK: llvm.mlir.constant(3 : i32) : i32
89+
90+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
91+
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
92+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
93+
amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
94+
95+
// CHECK: llvm.bitcast
96+
// CHECK: llvm.mlir.constant(4 : i32) : i32
97+
98+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
99+
amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
100+
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
101+
amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
102+
103+
func.return
104+
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,10 @@ func.func @swizzle_bitmode(%arg0 : f32) -> f32 {
164164
%0 = amdgpu.swizzle_bitmode %arg0 1 2 4 : f32
165165
func.return %0 : f32
166166
}
167+
168+
// CHECK-LABEL: func @scaled_mfma
169+
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
170+
// CHECK: amdgpu.scaled_mfma
171+
%0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
172+
func.return %0 : vector<16xf32>
173+
}

0 commit comments

Comments
 (0)