Skip to content

[MLIR][AMDGPU] Add support for fp8 ops on gfx12 #106388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def MFMAOutTypes : AnyTypeOf<[F64,
VectorOfLengthAndType<[4, 16, 32], [I32]>,
VectorOfLengthAndType<[4], [F64]>]>;
// wmma
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[16], [F16, BF16, I8, SI8, UI8]>]>;
def WMMAInTypes : AnyTypeOf<[VectorOfLengthAndType<[8, 16], [F16, BF16, I8, SI8, UI8, F8E4M3FN, F8E5M2]>]>;
def WMMAOutTypes : AnyTypeOf<[VectorOfLengthAndType<[4, 8], [F32, I32]>,
VectorOfLengthAndType<[8, 16], [F16, BF16]>]>;

Expand Down
7 changes: 5 additions & 2 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def ROCDL_BallotOp :
let summary = "Vote across thread group";

let description = [{
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
Ballot provides a bit mask containing the 1-bit predicate value from each lane.
The nth bit of the result contains the 1 bit contributed by the nth warp lane.
}];

Expand Down Expand Up @@ -328,13 +328,16 @@ class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,
"$args attr-dict `:` functional-type($args, $res)";
}

// Available on RDNA3
// Available from gfx11
def ROCDL_wmma_f32_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.f16", [0]>;
def ROCDL_wmma_f32_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf16", [0]>;
def ROCDL_wmma_f16_16x16x16_f16 : ROCDL_Wmma_IntrOp<"wmma.f16.16x16x16.f16", [0]>;
def ROCDL_wmma_bf16_16x16x16_bf16 : ROCDL_Wmma_IntrOp<"wmma.bf16.16x16x16.bf16", [0]>;
def ROCDL_wmma_i32_16x16x16_iu8 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu8", [1]>;
def ROCDL_wmma_i32_16x16x16_iu4 : ROCDL_Wmma_IntrOp<"wmma.i32.16x16x16.iu4", [1]>;
// Available from gfx12
def ROCDL_wmma_f32_16x16x16_fp8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.fp8_fp8", [1]>;
def ROCDL_wmma_f32_16x16x16_bf8 : ROCDL_Wmma_IntrOp<"wmma.f32.16x16x16.bf8_bf8", [1]>;

//===---------------------------------------------------------------------===//
// Operations on raw buffer resources (stride of 0, bounds checks either off or in
Expand Down
53 changes: 31 additions & 22 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
Location loc,
const TypeConverter *typeConverter,
bool isUnsigned, Value llvmInput,
Value mlirInput,
SmallVector<Value, 4> &operands) {
Type inputType = llvmInput.getType();
auto vectorType = dyn_cast<VectorType>(inputType);
Expand All @@ -398,23 +399,29 @@ static void wmmaPushInputOperand(ConversionPatternRewriter &rewriter,
return;
}

// We need to check the type of the input before conversion to properly test
// for int8. This is because, in LLVM, fp8 type is converted to int8, so the
// fp8/int8 information is lost during the conversion process.
auto mlirInputType = cast<VectorType>(mlirInput.getType());
bool isInputInt8 = mlirInputType.getElementType().isInteger(8);
if (isInputInt8) {
// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
bool localIsUnsigned = isUnsigned;
if (elemType.isUnsignedInteger(8)) {
localIsUnsigned = true;
} else if (elemType.isSignedInteger(8)) {
localIsUnsigned = false;
}
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
operands.push_back(sign);
}

int64_t numBytes = vectorType.getNumElements();
Type i32 = rewriter.getI32Type();
VectorType vectorType32bits = VectorType::get(numBytes * 8 / 32, i32);
auto llvmVectorType32bits = typeConverter->convertType(vectorType32bits);

Value result = rewriter.createOrFold<LLVM::BitcastOp>(
loc, llvmVectorType32bits, llvmInput);

// if element type is 8-bit signed or unsigned, ignore the isUnsigned flag
bool localIsUnsigned = isUnsigned;
if (elemType.isUnsignedInteger(8)) {
localIsUnsigned = true;
} else if (elemType.isSignedInteger(8)) {
localIsUnsigned = false;
}
Value sign = createI1Constant(rewriter, loc, !localIsUnsigned);
operands.push_back(sign);
operands.push_back(result);
}

Expand Down Expand Up @@ -590,18 +597,20 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
auto elemSourceType = sourceVectorType.getElementType();
auto elemDestType = destVectorType.getElementType();

if (elemSourceType.isF16() && elemDestType.isF32()) {
if (elemSourceType.isF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
}
if (elemSourceType.isBF16() && elemDestType.isF32()) {
if (elemSourceType.isBF16() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf16::getOperationName();
} else if (elemSourceType.isF16() && elemDestType.isF16()) {
if (elemSourceType.isF16() && elemDestType.isF16())
return ROCDL::wmma_f16_16x16x16_f16::getOperationName();
} else if (elemSourceType.isBF16() && elemDestType.isBF16()) {
if (elemSourceType.isBF16() && elemDestType.isBF16())
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
} else if (elemSourceType.isInteger(8) && elemDestType.isInteger(32)) {
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
}
if (elemSourceType.isFloat8E4M3FN() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_fp8::getOperationName();
if (elemSourceType.isFloat8E5M2() && elemDestType.isF32())
return ROCDL::wmma_f32_16x16x16_bf8::getOperationName();
return std::nullopt;
}

Expand Down Expand Up @@ -662,8 +671,8 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
Location loc = op.getLoc();
Type outType = typeConverter->convertType(op.getDestD().getType());

if (chipset.majorVersion != 11)
return op->emitOpError("WMMA only supported on gfx11");
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
return op->emitOpError("WMMA only supported on gfx11 and gfx12");

std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);

Expand All @@ -675,9 +684,9 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {

SmallVector<Value, 4> operands;
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
adaptor.getSourceA(), operands);
adaptor.getSourceA(), op.getSourceA(), operands);
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedB(),
adaptor.getSourceB(), operands);
adaptor.getSourceB(), op.getSourceB(), operands);
wmmaPushOutputOperand(rewriter, loc, typeConverter, adaptor.getDestC(),
op.getSubwordOffset(), op.getClamp(), operands);

Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,10 @@ LogicalResult WMMAOp::verify() {
Type sourceAElemType = sourceVectorAType.getElementType();
Type destElemType = destVectorType.getElementType();

bool isDestFloat =
(destElemType.isF32() || destElemType.isF16() || destElemType.isBF16());
bool isSrcFloat = (sourceAElemType.isF16() || sourceAElemType.isBF16());
bool isDestFloat = isa<Float32Type, Float16Type, BFloat16Type>(destElemType);
bool isSrcFloat =
isa<Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(
sourceAElemType);

if (isDestFloat && !isSrcFloat) {
return emitOpError("Expected float sources with float destination");
Expand Down
9 changes: 9 additions & 0 deletions mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx12.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1200 --allow-unregistered-dialect | FileCheck %s
func.func @mfma_to_rocdl(%arg0 : vector<8xf8E4M3FN>, %arg1 : vector<8xf8E5M2>, %arg2 : vector<8xf32>) {
// CHECK: rocdl.wmma.f32.16x16x16.fp8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
amdgpu.wmma %arg0 * %arg0 + %arg2: vector<8xf8E4M3FN>, vector<8xf8E4M3FN>, vector<8xf32>

// CHECK: rocdl.wmma.f32.16x16x16.bf8{{.*}}: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>
amdgpu.wmma %arg1 * %arg1 + %arg2: vector<8xf8E5M2>, vector<8xf8E5M2>, vector<8xf32>
func.return
}
10 changes: 10 additions & 0 deletions mlir/test/Target/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,16 @@ llvm.func @rocdl.make.buffer.rsrc(%ptr : !llvm.ptr,
llvm.return %rsrc : !llvm.ptr<8>
}

llvm.func @rocdl.wmma.fp8(%arg0 : vector<2 x i32>, %arg1 : vector<8xf32>) -> vector<8xf32> {
// CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.fp8.fp8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
%r0 = rocdl.wmma.f32.16x16x16.fp8_fp8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>

// CHECK: call <8 x float> @llvm.amdgcn.wmma.f32.16x16x16.bf8.bf8.v8f32.v2i32(<2 x i32> %{{.*}}, <2 x i32> %{{.*}}, <8 x float> %{{.*}})
%r1 = rocdl.wmma.f32.16x16x16.bf8_bf8 %arg0, %arg0, %arg1: (vector<2xi32>, vector<2xi32>, vector<8xf32>) -> vector<8xf32>

llvm.return %r0 : vector<8 x f32>
}

llvm.func @rocdl.raw.ptr.buffer(%rsrc : !llvm.ptr<8>,
%offset : i32, %soffset : i32,
%vdata1 : i32,
Expand Down
Loading