Skip to content

Commit 6292ea6

Browse files
krzysz00 and kuhar authored
[mlir][AMDGPU] Remove an old bf16 workaround (#108409)
The AMDGPU backend now implements LLVM's `bfloat` type. Therefore, we no longer need to type convert MLIR's `bf16` to `i16` during lowerings to ROCDL. As a result of this change, we discovered that, while the code for MFMA and WMMA intrinsics was mainly prepared for this change, we were failing to bitcast the bf16 results of WMMA operations out from the i16 they're natively represented as. This commit also fixes that issue. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent 48088dc commit 6292ea6

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -671,18 +671,27 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
671671
matchAndRewrite(WMMAOp op, WMMAOpAdaptor adaptor,
672672
ConversionPatternRewriter &rewriter) const override {
673673
Location loc = op.getLoc();
674-
Type outType = typeConverter->convertType(op.getDestD().getType());
674+
auto outType =
675+
typeConverter->convertType<VectorType>(op.getDestD().getType());
676+
if (!outType)
677+
return rewriter.notifyMatchFailure(op, "type conversion failed");
675678

676679
if (chipset.majorVersion != 11 && chipset.majorVersion != 12)
677680
return op->emitOpError("WMMA only supported on gfx11 and gfx12");
678681

682+
// The WMMA operations represent vectors of bf16s as vectors of i16s, so we
683+
// need to bitcast bfloats to i16 and then bitcast them back.
684+
VectorType rawOutType = outType;
685+
if (outType.getElementType().isBF16())
686+
rawOutType = outType.clone(rewriter.getI16Type());
687+
679688
std::optional<StringRef> maybeIntrinsic = wmmaOpToIntrinsic(op, chipset);
680689

681690
if (!maybeIntrinsic.has_value())
682691
return op.emitOpError("no intrinsic matching WMMA on the given chipset");
683692

684693
OperationState loweredOp(loc, *maybeIntrinsic);
685-
loweredOp.addTypes(outType);
694+
loweredOp.addTypes(rawOutType);
686695

687696
SmallVector<Value, 4> operands;
688697
wmmaPushInputOperand(rewriter, loc, typeConverter, op.getUnsignedA(),
@@ -694,7 +703,12 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
694703

695704
loweredOp.addOperands(operands);
696705
Operation *lowered = rewriter.create(loweredOp);
697-
rewriter.replaceOp(op, lowered->getResults());
706+
707+
Operation *maybeCastBack = lowered;
708+
if (rawOutType != outType)
709+
maybeCastBack =
710+
rewriter.create<LLVM::BitcastOp>(loc, outType, lowered->getResult(0));
711+
rewriter.replaceOp(op, maybeCastBack->getResults());
698712

699713
return success();
700714
}
@@ -1033,15 +1047,6 @@ struct ConvertAMDGPUToROCDLPass
10331047
void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
10341048
RewritePatternSet &patterns,
10351049
Chipset chipset) {
1036-
converter.addConversion([](BFloat16Type t) -> Type {
1037-
return IntegerType::get(t.getContext(), 16);
1038-
});
1039-
converter.addConversion([&converter](VectorType t) -> std::optional<Type> {
1040-
if (!t.getElementType().isBF16())
1041-
return std::nullopt;
1042-
return converter.convertType(t.clone(IntegerType::get(t.getContext(), 16)));
1043-
});
1044-
10451050
patterns
10461051
.add<RawBufferOpLowering<RawBufferLoadOp, ROCDL::RawPtrBufferLoadOp>,
10471052
RawBufferOpLowering<RawBufferStoreOp, ROCDL::RawPtrBufferStoreOp>,

mlir/test/Conversion/AMDGPUToROCDL/wmma.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ func.func @mfma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 :
1515
amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16>
1616
// CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16>
1717
amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16>
18-
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
18+
// CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16>
19+
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16>
1920
amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16>
20-
// CHECK: rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
21+
// CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16>
22+
// CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16>
2123
amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16>
2224
// CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32>
2325
amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<4xi32>

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -507,22 +507,22 @@ gpu.module @test_module {
507507

508508
// -----
509509

510-
// Test that the bf16 type is lowered away on this target.
510+
// Test that the bf16 type is passed through to LLVM.
511511

512512
gpu.module @test_module {
513513
// CHECK-LABEL: func @bf16_id
514514
func.func @bf16_id(%arg0 : bf16) -> bf16 {
515-
// CHECK-SAME: (%[[ARG0:.+]]: i16)
516-
// CHECK-SAME: -> i16
517-
// CHECK: return %[[ARG0]] : i16
515+
// CHECK-SAME: (%[[ARG0:.+]]: bf16)
516+
// CHECK-SAME: -> bf16
517+
// CHECK: return %[[ARG0]] : bf16
518518
func.return %arg0 : bf16
519519
}
520520

521521
// CHECK-LABEL: func @bf16x4_id
522522
func.func @bf16x4_id(%arg0 : vector<4xbf16>) -> vector<4xbf16> {
523-
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xi16>)
524-
// CHECK-SAME: -> vector<4xi16>
525-
// CHECK: return %[[ARG0]] : vector<4xi16>
523+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xbf16>)
524+
// CHECK-SAME: -> vector<4xbf16>
525+
// CHECK: return %[[ARG0]] : vector<4xbf16>
526526
func.return %arg0 : vector<4xbf16>
527527
}
528528

0 commit comments

Comments
 (0)