-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] ArithToLLVM: fix memref bitcast lowering #125148
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Ivan Butygin (Hardcode84) Changes
With the opaque pointers this is a no-op operation for memref so we can just add type check in Full diff: https://github.com/llvm/llvm-project/pull/125148.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed898142936..b726faa92a03a0 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,23 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};
+/// No-op bitcast.
+struct IdentityBitcastLowering final
+ : public OpConversionPattern<arith::BitcastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ Value src = adaptor.getIn();
+ if (src.getType() != getTypeConverter()->convertType(op.getType()))
+ return rewriter.notifyMatchFailure(op, "Types are different");
+
+ rewriter.replaceOp(op, src);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
@@ -524,6 +541,9 @@ void mlir::arith::registerConvertArithToLLVMInterface(
void mlir::arith::populateArithToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+ patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
+
// clang-format off
patterns.add<
AddFOpLowering,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 626135c10a3e96..c9d3b57b0d596e 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success();
}
+static bool isVectorCompatibleType(Type type) {
+ return isa<LLVM::LLVMArrayType, VectorType, IntegerType, FloatType>(type) &&
+ LLVM::isCompatibleType(type);
+}
+
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
@@ -111,7 +116,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
assert(!operands.empty());
// Cannot convert ops if their operands are not of LLVM type.
- if (!llvm::all_of(operands.getTypes(), isCompatibleType))
+ if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
return failure();
auto llvmNDVectorTy = operands[0].getType();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47cc..9a6c4bca88f3bf 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
+// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+// CHECK: return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+ %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+ func.return %2 : memref<?xbf16>
+}
|
matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const final { | ||
Value src = adaptor.getIn(); | ||
if (src.getType() != getTypeConverter()->convertType(op.getType())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find this a bit difficult to follow. Would it make sense to change this to: if (!isa<MemRefType>(src.getType())
? And rename the pattern to MemRefBitcastLowering
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if we need any memref-specific logic here, just handling same input/output converted types should be enough.
@@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( | |||
return success(); | |||
} | |||
|
|||
static bool isVectorCompatibleType(Type type) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can just LLVM::isCompatibleType
be used here? It already checks for LLVMArrayType
, VectorType
, etc. Alternatively, there is also LLVM::isCompatibleVectorType
, which may be useful here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was using LLVM::isCompatibleType
before, but it's too broad, I specifically want to limit this transform to scalar and vector types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of fiddling, can you just set PatternBenefit
on the bitcast pattern?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While PatternBenefit
will probably work in this specific case, I still think there is a potential problem in vectorOneToOneRewrite
as it can generate llvm bitcasts for unsupported types like structs.
arith.bitcast is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal. With the opaque pointers this is a no-op operation for memref so we can just add type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern which removes op if converted types are the same.
ping |
Looks generally ok to me but is outside of my area of expertise (especially with regards to what types the vector lowering supports etc). Is there an owner that could be assigned for a review? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there're ways to deal with this that don't require messing with VectorPattern.cpp
@@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors( | |||
return success(); | |||
} | |||
|
|||
static bool isVectorCompatibleType(Type type) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of fiddling, can you just set PatternBenefit
on the bitcast pattern?
@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern | |||
} | |||
}; | |||
|
|||
/// No-op bitcast. Propagate type input arg if converted source and dest types | |||
/// are the same. | |||
struct IdentityBitcastLowering final |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think, to get the behavior you want here where identity bitcast on memref gets folded away before the general pattern for arith.bitcast
kicks in, you want to set a PatternBenefit
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see your point about tightening up the bounds on a utility that shouldn't be used for stuff that doesn't work.
What I was worried about is that the one-to-one pattern might be used outside of the arith
lowerings where being able to deal with a struct
could come up
At the very least, we should allow pointers.
But, in the interests of not blocking things and of it not being an unreasonable tightening, approved.
88d92bf
to
01d05df
Compare
Added pointer support, but it would most likely will never be triggered as the only upstream usages are arith/math lowering. |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/12872 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/10142 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/10206 Here is the relevant piece of the build log for the reference
|
This reverts commit 79010e2.
Reverts #125148 bot failures
…#126895) Reverts llvm/llvm-project#125148 bot failures
@krzysz00 You were actually right, it caused issues with |
`arith.bitcast` is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal. With the opaque pointers this is a no-op operation for memref so we can just add a separate pattern which removes op if converted types are the same.
Updated PR #126939 |
…25148) (#126939) Reland llvm/llvm-project#125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
`arith.bitcast` is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal. With the opaque pointers this is a no-op operation for memref so we can just add type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern which removes op if converted types are the same.
Reverts llvm#125148 bot failures
…lvm#126939) Reland llvm#125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
`arith.bitcast` is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal. With the opaque pointers this is a no-op operation for memref so we can just add type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern which removes op if converted types are the same.
Reverts llvm#125148 bot failures
…lvm#126939) Reland llvm#125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
`arith.bitcast` is allowed on memrefs and such code can actually be generated by IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate bitcast between structs which is illegal. With the opaque pointers this is a no-op operation for memref so we can just add type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern which removes op if converted types are the same.
Reverts llvm#125148 bot failures
…lvm#126939) Reland llvm#125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
arith.bitcast
is allowed on memrefs and such code can actually be generated by IREEConvertBf16ArithToF32Pass
.LLVM::detail::vectorOneToOneRewrite
doesn't properly check its types and will generate bitcast between structs which is illegal.With the opaque pointers this is a no-op operation for memref so we can just add type check in
LLVM::detail::vectorOneToOneRewrite
and add a separate pattern which removes op if converted types are the same.