-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Reland [mlir] ArithToLLVM: fix memref bitcast lowering (#125148) #126939
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
`arith.bitcast` is allowed on memrefs, and such code can actually be generated by IREE's `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types and will generate a bitcast between structs, which is illegal. With opaque pointers this bitcast is a no-op for memrefs, so we can just add a separate pattern that removes the op when the converted source and result types are the same.
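@llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) Changes: Reland #125148. Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern. Full diff: https://github.com/llvm/llvm-project/pull/126939.diff 2 Files Affected: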
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed89814293..ced18a48766bf 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};
+/// No-op bitcast. Propagate type input arg if converted source and dest types
+/// are the same.
+struct IdentityBitcastLowering final
+ : public OpConversionPattern<arith::BitcastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ Value src = adaptor.getIn();
+ Type resultType = getTypeConverter()->convertType(op.getType());
+ if (src.getType() != resultType)
+ return rewriter.notifyMatchFailure(op, "Types are different");
+
+ rewriter.replaceOp(op, src);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Straightforward Op Lowerings
//===----------------------------------------------------------------------===//
@@ -524,6 +543,12 @@ void mlir::arith::registerConvertArithToLLVMInterface(
void mlir::arith::populateArithToLLVMConversionPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+ // Set a higher pattern benefit for IdentityBitcastLowering so it will run
+ // before BitcastOpLowering.
+ patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
+ /*patternBenefit*/ 10);
+
// clang-format off
patterns.add<
AddFOpLowering,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47c..7daf4ef8717bc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -577,12 +577,26 @@ func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
// -----
// CHECK-LABEL: @select
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
func.func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 {
- // CHECK: = llvm.select %arg0, %arg1, %arg2 : i1, i32
+ // CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARG1]], %[[ARG2]] : i1, i32
+ // CHECK: return %[[RES]]
%0 = arith.select %arg0, %arg1, %arg2 : i32
return %0 : i32
}
+// CHECK-LABEL: @select_complex
+// CHECK-SAME: (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: complex<f32>, %[[ARG2:.*]]: complex<f32>)
+func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>) -> complex<f32> {
+ // CHECK-DAG: %[[ARGC1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : complex<f32> to !llvm.struct<(f32, f32)>
+ // CHECK-DAG: %[[ARGC2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : complex<f32> to !llvm.struct<(f32, f32)>
+ // CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARGC1]], %[[ARGC2]] : i1, !llvm.struct<(f32, f32)>
+ // CHECK: %[[RESC:.*]] = builtin.unrealized_conversion_cast %[[RES]] : !llvm.struct<(f32, f32)> to complex<f32>
+ // CHECK: return %[[RESC]]
+ %0 = arith.select %arg0, %arg1, %arg2 : complex<f32>
+ return %0 : complex<f32>
+}
+
// -----
// CHECK-LABEL: @ceildivsi
@@ -727,3 +741,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
%3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
return
}
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+// CHECK-SAME: (%[[ARG:.*]]: memref<?xi16>)
+// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK: %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+// CHECK: return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+ %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+ func.return %2 : memref<?xbf16>
+}
|
…lvm#126939) Reland llvm#125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
…lvm#126939) Reland llvm#125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
…lvm#126939) Reland llvm#125148 Limiting vector pattern caused issues with `select` of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to `IdentityBitcastLowering` so it will always run before the vector pattern.
Reland #125148
Limiting the vector pattern caused issues with the lowering of `select` on complex values, which wasn't caught because lit tests for it were missing. Keep the pattern as is for now and instead set a higher benefit on `IdentityBitcastLowering` so it will always run before the vector pattern.