[MLIR] AMDGPUToROCDL: Use a bitcast op to reinterpret a vector of i8 as single integer. #111400
Conversation
Signed-off-by: Benoit Jacob <[email protected]>
Force-pushed 94a9b01 to f0909e4
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu

Author: Benoit Jacob (bjacob)

Changes

Found by inspecting AMDGPU assembly - so the arithmetic ops created there were definitely making their way into the target ISA. An `LLVM::BitcastOp` seems equivalent, and evaporates as expected in the target asm.

Along the way, I thought that this helper function `mfmaConcatIfNeeded` could be renamed to `convertMFMAVectorOperand` to better convey its contract.

Full diff: https://github.com/llvm/llvm-project/pull/111400.diff

1 file affected:
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 2b33f3773dc7d1..0ccd4133d3761d 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -351,39 +351,23 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
} // namespace
-/// If `input` is a vector of bytes, concatentate those bytes in little-endian
-/// order to form a single integer of size 8 * [vector length]. This works
-/// around a wart in the AMDGPU intrinsics where operations that logically take
-/// vectors of bytes instead integers. Since we do not want to expose this
-/// implementation detail to MLIR, we correct for it here.
+/// Converts a MFMA vector operand from MLIR AMDGPU dialect convention to ROCDL
+/// and LLVM AMDGPU intrinsics convention.
///
-/// In addition, convert vectors of LLVM bfloats to vectors of i16, since AMDGPU
-/// MFMA intrinsics pre-date the bfloat type.
-static Value mfmaConcatIfNeeded(ConversionPatternRewriter &rewriter,
- Location loc, Value input) {
+/// Specifically:
+/// 1. If `input` is a vector of N bytes, bitcast it to a (N * 8)-bit integer.
+/// 2. If the element type is bfloat16, bitcast it to i16.
+static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
+ Location loc, Value input) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
if (vectorType.getElementType().isBF16())
return rewriter.create<LLVM::BitcastOp>(
loc, vectorType.clone(rewriter.getI16Type()), input);
-
- if (!vectorType.getElementType().isInteger(8))
- return input;
- int64_t numBytes = vectorType.getNumElements();
- Type destType = rewriter.getIntegerType(numBytes * 8);
- Value result = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, 0));
- for (int64_t i = 0; i < numBytes; ++i) {
- Value idxConst = createI32Constant(rewriter, loc, i);
- Value element =
- rewriter.create<LLVM::ExtractElementOp>(loc, input, idxConst);
- Value extended = rewriter.create<LLVM::ZExtOp>(loc, destType, element);
- Value shiftConst = rewriter.create<LLVM::ConstantOp>(
- loc, destType, rewriter.getIntegerAttr(destType, i * 8));
- Value shifted = rewriter.create<LLVM::ShlOp>(loc, extended, shiftConst);
- result = rewriter.create<LLVM::OrOp>(loc, result, shifted);
+ if (vectorType.getElementType().isInteger(8)) {
+ return rewriter.create<LLVM::BitcastOp>(
+ loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
}
- return result;
}
return input;
}
@@ -656,8 +640,8 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
OperationState loweredOp(loc, *maybeIntrinsic);
loweredOp.addTypes(intrinsicOutType);
loweredOp.addOperands(
- {mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceA()),
- mfmaConcatIfNeeded(rewriter, loc, adaptor.getSourceB()),
+ {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
+ convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
adaptor.getDestC(), createI32Constant(rewriter, loc, op.getCbsz()),
createI32Constant(rewriter, loc, op.getAbid()),
createI32Constant(rewriter, loc, getBlgpField)});
@llvm/pr-subscribers-backend-amdgpu: same summary and diff as above.
missing test changes
It's because the test actually only checks that the
Can we have a testcase for this?
Ah, just saw that @arsenm had the same comment
Approved conditional on "whoops, I forgot to add a test for that bit, please fix" above
Added the test. Note: added
Pushed one more commit: hadn't noticed that the bf8 and fp8 instructions are also getting their bytes as i64. Crazy :-D
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
                                      Location loc, Value input) {
  Type inputType = input.getType();
  if (auto vectorType = dyn_cast<VectorType>(inputType)) {
    if (vectorType.getElementType().isBF16())
      return rewriter.create<LLVM::BitcastOp>(
          loc, vectorType.clone(rewriter.getI16Type()), input);
    if (vectorType.getElementType().isInteger(8)) {
      return rewriter.create<LLVM::BitcastOp>(
          loc, rewriter.getIntegerType(vectorType.getNumElements() * 8), input);
    }
  }
  return input;
}
This whole function can just create one bitcast? I don't see why you need to consider the element types. Especially since bf16 should be natively consumed now
I thought about that, but thought that maybe we want to avoid creating identity bitcasts unconditionally. And then if we're doing it conditionally when it would be needed, the code starts looking like its current form.
Feel free to improve this code further in a follow-up!
Avoiding identity bitcasts is one compare that the type matches (and the IRBuilder at least does that for you)
Good point! But I mean: we sometimes want to bitcast, and sometimes not. We don't bitcast f32 and f16's. So (even after the bf16 simplification you mentioned) we still need to have some logic based on element types. I'd rather defer any further simplification to you as a follow-up, since you were aware of things such as this bf16 simplification, which I wasn't.
There's also a nuance: in the bf16 case, we bitcast to a vector of i16, while in the case of byte-size element types, we bitcast to a raw integer of higher bit width, not a vector of integers of the original element type bit width.
Also, the MFMA intrinsics still take i16, last I checked.