
Commit d8a656f

[MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as a single integer. (#111400)
Found by inspecting AMDGPU assembly - so the arithmetic ops created there were definitely making their way into the target ISA. An `LLVM::BitcastOp` seems equivalent, and evaporates as expected in the target asm. Along the way, I thought that this helper function `mfmaConcatIfNeeded` could be renamed to `convertMFMAVectorOperand` to better convey its contract, so I don't need to think about whether a bitcast is a legitimate "concat" :-)

Signed-off-by: Benoit Jacob <[email protected]>
1 parent fabe7e3 · commit d8a656f
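For context, a rough before/after sketch of the LLVM-dialect IR this lowering produces for a vector<4xi8> MFMA operand. This is an illustration only: value names such as %operand are not from the patch, and only the first byte of the old packing loop is shown.

// Old lowering: pack the bytes with an explicit zext/shl/or chain
// (one iteration shown; the pattern repeats for each of the 4 bytes).
%zero = llvm.mlir.constant(0 : i32) : i32
%idx0 = llvm.mlir.constant(0 : i32) : i32
%b0   = llvm.extractelement %operand[%idx0 : i32] : vector<4xi8>
%w0   = llvm.zext %b0 : i8 to i32
%sh0  = llvm.mlir.constant(0 : i32) : i32
%v0   = llvm.shl %w0, %sh0 : i32
%acc0 = llvm.or %zero, %v0 : i32
// ... bytes 1..3 ...

// New lowering: a single bitcast, which folds away in the final ISA.
%packed = llvm.bitcast %operand : vector<4xi8> to i32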

2 files changed (+25 additions, -35 deletions)


mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 12 additions & 28 deletions
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
 ///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
-                                Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                      Location loc, Value input) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-
-    if (!vectorType.getElementType().isInteger(8))
-      return input;
-    int64_t numBytes = vectorType.getNumElements();
-    Type destType = rewriter.getIntegerType(numBytes * 8);
-    Value result = rewriter.create<LLVM::ConstantOp>(
-        loc, destType, rewriter.getIntegerAttr(destType, 0));
-    for (int64_t i = 0; i < numBytes; ++i) {
-      Value idxConst = createI32Constant(rewriter, loc, i);
-      Value element =
-          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
-      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
-      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
-          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
-      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
-      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    if (vectorType.getElementType().isInteger(8)) {
+      return rewriter.create<LLVM::BitcastOp>(
+          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
     }
-    return result;
   }
   return input;
 }
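For illustration, the renamed helper reduces both conversions to bitcasts. A minimal sketch of the LLVM-dialect ops it emits, with illustrative operand names (%v_bf16, %v_i8) that are not from the patch:

// bf16 element type: reinterpret elementwise as i16, since the MFMA
// intrinsics pre-date the bfloat type.
%a = llvm.bitcast %v_bf16 : vector<4xbf16> to vector<4xi16>
// i8 element type: reinterpret the whole N-byte vector as an (N * 8)-bit integer.
%b = llvm.bitcast %v_i8 : vector<8xi8> to i64
// Any other operand type is passed through unchanged.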
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
-        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
-         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
          createI32Constant(rewriter, loc, op.getAbid()),
          createI32Constant(rewriter, loc, getBlgpField)});

mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir

Lines changed: 13 additions & 7 deletions
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 -cse | FileCheck %s
 func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
                          %arg2 : vector<16xf32>, %arg3 : vector<4xf32>,
                          %arg4 : vector<4xf16>, %arg5 : vector<4xi8>,
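The new -cse run presumably matters for the updated checks below: when both MFMA sources are the same SSA value, the lowering emits two identical bitcasts, and CSE folds them so a single FileCheck capture can be reused for both intrinsic operands. A rough sketch of the effect, with illustrative value names:

// Without CSE: each source operand is converted independently.
%a0 = llvm.bitcast %arg5 : vector<4xi8> to i32
%a1 = llvm.bitcast %arg5 : vector<4xi8> to i32
// With -cse: the duplicate is removed, so the same value feeds both operands,
// matching checks such as %[[BITCAST_4xi8_i32]] appearing twice.
%a = llvm.bitcast %arg5 : vector<4xi8> to i32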
@@ -28,7 +28,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
   amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
-  // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
+  // CHECK: %[[BITCAST_4xi8_i32:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
+  // CHECK: rocdl.mfma.i32.32x32x4i8 %[[BITCAST_4xi8_i32]], %[[BITCAST_4xi8_i32]], {{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
   amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
   // CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
   amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
@@ -38,7 +39,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
   amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
   amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
-  // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+  // CHECK: %[[BITCAST_2xbf16_2xi16:.+]] = llvm.bitcast {{.*}} : vector<2xbf16> to vector<2xi16>
+  // CHECK: rocdl.mfma.f32.32x32x2bf16 %[[BITCAST_2xbf16_2xi16]], %[[BITCAST_2xbf16_2xi16]], %{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
   amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
   amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
@@ -48,7 +50,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
   amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
   // CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
-  // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
+  // CHECK: %[[BITCAST_4xbf16_4xi16:.+]] = llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
+  // CHECK: rocdl.mfma.f32.32x32x4bf16.1k %[[BITCAST_4xbf16_4xi16]], %[[BITCAST_4xbf16_4xi16]], {{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
   amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
   // CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
   amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
@@ -62,17 +65,20 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
   amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, f64, vector<4xf64>
   // CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
   amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64, f64
-  // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
+  // CHECK: %[[BITCAST_8xi8_i64:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
+  // CHECK: rocdl.mfma.i32.16x16x32.i8 %[[BITCAST_8xi8_i64]], %[[BITCAST_8xi8_i64]], {{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
   amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
   // CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
   amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
   // CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
   amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
-  // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  // CHECK: %[[BITCAST_8xi8_i64_1:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
+  // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_1]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
-  // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
+  // CHECK: %[[BITCAST_8xi8_i64_2:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
+  // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_2]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
   // CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
   amdgpu.mfma %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
