@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
351
351
352
352
} // namespace
353
353
354
- // / If `input` is a vector of bytes, concatentate those bytes in little-endian
355
- // / order to form a single integer of size 8 * [vector length]. This works
356
- // / around a wart in the AMDGPU intrinsics where operations that logically take
357
- // / vectors of bytes instead integers. Since we do not want to expose this
358
- // / implementation detail to MLIR, we correct for it here.
354
+ // / Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
355
+ // / and LLVM AMDGPU intrinsics convention.
359
356
// /
360
- // / In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
361
- // / MFMA intrinsics pre-date the bfloat type.
362
- static Value mfmaConcatIfNeeded (ConversionPatternRewriter &rewriter,
363
- Location loc, Value input) {
357
+ // / Specifically:
358
+ // / 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
359
+ // / 2. If the element type is bfloat16, bitcast it to i16.
360
+ static Value convertMFMAVectorOperand (ConversionPatternRewriter &rewriter,
361
+ Location loc, Value input) {
364
362
Type inputType = input.getType ();
365
363
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
366
364
if (vectorType.getElementType ().isBF16 ())
367
365
return rewriter.create <LLVM::BitcastOp>(
368
366
loc, vectorType.clone (rewriter.getI16Type ()), input);
369
-
370
- if (!vectorType.getElementType ().isInteger (8 ))
371
- return input;
372
- int64_t numBytes = vectorType.getNumElements ();
373
- Type destType = rewriter.getIntegerType (numBytes * 8 );
374
- Value result = rewriter.create <LLVM::ConstantOp>(
375
- loc, destType, rewriter.getIntegerAttr (destType, 0 ));
376
- for (int64_t i = 0 ; i < numBytes; ++i) {
377
- Value idxConst = createI32Constant (rewriter, loc, i);
378
- Value element =
379
- rewriter.create <LLVM::ExtractElementOp>(loc, input, idxConst);
380
- Value extended = rewriter.create <LLVM::ZExtOp>(loc, destType, element);
381
- Value shiftConst = rewriter.create <LLVM::ConstantOp>(
382
- loc, destType, rewriter.getIntegerAttr (destType, i * 8 ));
383
- Value shifted = rewriter.create <LLVM::ShlOp>(loc, extended, shiftConst);
384
- result = rewriter.create <LLVM::OrOp>(loc, result, shifted);
367
+ if (vectorType.getElementType ().isInteger (8 )) {
368
+ return rewriter.create <LLVM::BitcastOp>(
369
+ loc, rewriter.getIntegerType (vectorType.getNumElements () * 8 ), input);
385
370
}
386
- return result;
387
371
}
388
372
return input;
389
373
}
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
656
640
OperationState loweredOp (loc, *maybeIntrinsic);
657
641
loweredOp.addTypes (intrinsicOutType);
658
642
loweredOp.addOperands (
659
- {mfmaConcatIfNeeded (rewriter, loc, adaptor.getSourceA ()),
660
- mfmaConcatIfNeeded (rewriter, loc, adaptor.getSourceB ()),
643
+ {convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceA ()),
644
+ convertMFMAVectorOperand (rewriter, loc, adaptor.getSourceB ()),
661
645
adaptor.getDestC (), createI32Constant (rewriter, loc, op.getCbsz ()),
662
646
createI32Constant (rewriter, loc, op.getAbid ()),
663
647
createI32Constant (rewriter, loc, getBlgpField)});
0 commit comments