
[MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as single integer. #111400


Merged

merged 3 commits into llvm:main on Oct 7, 2024

Conversation

bjacob
Contributor

@bjacob bjacob commented Oct 7, 2024

Found by inspecting AMDGPU assembly: the arithmetic ops created there were definitely making their way into the target ISA. An LLVM::BitcastOp seems equivalent, and evaporates as expected in the target asm.

Along the way, I thought that this helper function mfmaConcatIfNeeded could be renamed to convertMFMAVectorOperand to better convey its contract, so I don't need to think about whether a bitcast is a legitimate "concat" :-)

@bjacob bjacob changed the title [MLIR] AMDGPUToROCDL: Use a bitcast op to reintepret a vector is i8 as single integer. [MLIR] AMDGPUToROCDL: Use a bitcast op to reintepret a vector of i8 as single integer. Oct 7, 2024
Signed-off-by: Benoit Jacob <[email protected]>
@bjacob bjacob marked this pull request as ready for review October 7, 2024 16:43
@bjacob bjacob requested review from krzysz00 and kuhar October 7, 2024 16:43
@llvmbot
Member

llvmbot commented Oct 7, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Benoit Jacob (bjacob)


Full diff: https://github.com/llvm/llvm-project/pull/111400.diff

1 file affected:

  • (modified) mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp (+12-28)
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2b33f3773dc7d1..0ccd4133d3761d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
 
 } // namespace
 
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
 ///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
-                                Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+                                      Location loc, Value input) {
   Type inputType = input.getType();
   if (auto vectorType = dyn_cast<VectorType>(inputType)) {
     if (vectorType.getElementType().isBF16())
       return rewriter.create<LLVM::BitcastOp>(
           loc, vectorType.clone(rewriter.getI16Type()), input);
-
-    if (!vectorType.getElementType().isInteger(8))
-      return input;
-    int64_t numBytes = vectorType.getNumElements();
-    Type destType = rewriter.getIntegerType(numBytes * 8);
-    Value result = rewriter.create<LLVM::ConstantOp>(
-        loc, destType, rewriter.getIntegerAttr(destType, 0));
-    for (int64_t i = 0; i < numBytes; ++i) {
-      Value idxConst = createI32Constant(rewriter, loc, i);
-      Value element =
-          rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
-      Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
-      Value shiftConst = rewriter.create<LLVM::ConstantOp>(
-          loc, destType, rewriter.getIntegerAttr(destType, i * 8));
-      Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
-      result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+    if (vectorType.getElementType().isInteger(8)) {
+      return rewriter.create<LLVM::BitcastOp>(
+          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
     }
-    return result;
   }
   return input;
 }
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
     OperationState loweredOp(loc, *maybeIntrinsic);
     loweredOp.addTypes(intrinsicOutType);
     loweredOp.addOperands(
-        {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
-         mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+        {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+         convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
          adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
          createI32Constant(rewriter, loc, op.getAbid()),
          createI32Constant(rewriter, loc, getBlgpField)});

@llvmbot
Member

llvmbot commented Oct 7, 2024

@llvm/pr-subscribers-backend-amdgpu
Contributor

@arsenm arsenm left a comment

missing test changes

@bjacob
Contributor Author

bjacob commented Oct 7, 2024

missing test changes

It's because the test only checks that the mfma op is as expected; it wasn't covering the IR that is being simplified here. I will add that.

Member

@kuhar kuhar left a comment

Can we have a testcase for this?

@kuhar
Member

kuhar commented Oct 7, 2024

Ah, just saw that @arsenm had the same comment

Contributor

@krzysz00 krzysz00 left a comment

Approved conditional on "whoops, I forgot to add a test for that bit, please fix" above

@bjacob
Contributor Author

bjacob commented Oct 7, 2024

Added the test. Note: added -cse to the existing test file, because this test passes the same Value for LHS and RHS, so we are getting the same IR generated twice; writing the CHECKs against the non-CSE'd form seems like it would be fragile.

@bjacob
Contributor Author

bjacob commented Oct 7, 2024

Pushed one more commit: hadn't noticed that the bf8 and fp8 instructions are also getting their bytes as i64. Crazy :-D

Comment on lines +360 to 372
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    }
  }
  return input;
}
Contributor
This whole function can just create one bitcast? I don't see why you need to consider the element types. Especially since bf16 should be natively consumed now

Contributor Author
I thought about that, but maybe we want to avoid creating identity bitcasts unconditionally. And once we're doing it conditionally, only when it would be needed, the code starts looking like its current form.
Feel free to improve this code further in a follow-up!

Contributor
Avoiding identity bitcasts is just one comparison that the types match (and the IRBuilder, at least, does that for you).

Contributor Author
Good point! But I mean: we sometimes want to bitcast and sometimes not. We don't bitcast f32s and f16s. So (even after the bf16 simplification you mentioned) we still need some logic based on element types. I'd rather defer any further simplification to you as a follow-up, since you were aware of things such as the bf16 simplification, which I wasn't.

Contributor Author
There's also a nuance: in the bf16 case, we bitcast to a vector of i16, while in the case of byte-size element types, we bitcast to a raw integer of higher bit width, not a vector of integers of the original element type bit width.

Contributor
Also, the MFMA intrinsics still take i16, last I checked

@bjacob bjacob merged commit d8a656f into llvm:main Oct 7, 2024
7 of 8 checks passed
5 participants