Skip to content

Commit f0909e4

Browse files
committed
bitcast
Signed-off-by: Benoit Jacob <[email protected]>
1 parent d4c1789 commit f0909e4

File tree

1 file changed

+12
-28
lines changed

1 file changed

+12
-28
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
351351

352352
} // namespace
353353

354-
/// If `input` is a vector of bytes, concatentate those bytes in little-endian
355-
/// order to form a single integer of size 8 * [vector length]. This works
356-
/// around a wart in the AMDGPU intrinsics where operations that logically take
357-
/// vectors of bytes instead integers. Since we do not want to expose this
358-
/// implementation detail to MLIR, we correct for it here.
354+
/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
355+
/// and LLVM AMDGPU intrinsics convention.
359356
///
360-
/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
361-
/// MFMA intrinsics pre-date the bfloat type.
362-
static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
363-
Location loc, Value input) {
357+
/// Specifically:
358+
/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
359+
/// 2. If the element type is bfloat16, bitcast it to i16.
360+
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
361+
Location loc, Value input) {
364362
Type inputType = input.getType();
365363
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
366364
if (vectorType.getElementType().isBF16())
367365
return rewriter.create<LLVM::BitcastOp>(
368366
loc, vectorType.clone(rewriter.getI16Type()), input);
369-
370-
if (!vectorType.getElementType().isInteger(8))
371-
return input;
372-
int64_t numBytes = vectorType.getNumElements();
373-
Type destType = rewriter.getIntegerType(numBytes * 8);
374-
Value result = rewriter.create<LLVM::ConstantOp>(
375-
loc, destType, rewriter.getIntegerAttr(destType, 0));
376-
for (int64_t i = 0; i < numBytes; ++i) {
377-
Value idxConst = createI32Constant(rewriter, loc, i);
378-
Value element =
379-
rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
380-
Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
381-
Value shiftConst = rewriter.create<LLVM::ConstantOp>(
382-
loc, destType, rewriter.getIntegerAttr(destType, i * 8));
383-
Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
384-
result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
367+
if (vectorType.getElementType().isInteger(8)) {
368+
return rewriter.create<LLVM::BitcastOp>(
369+
loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
385370
}
386-
return result;
387371
}
388372
return input;
389373
}
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
656640
OperationState loweredOp(loc, *maybeIntrinsic);
657641
loweredOp.addTypes(intrinsicOutType);
658642
loweredOp.addOperands(
659-
{mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
660-
mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
643+
{convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
644+
convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
661645
adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
662646
createI32Constant(rewriter, loc, op.getAbid()),
663647
createI32Constant(rewriter, loc, getBlgpField)});

0 commit comments

Comments
 (0)