Skip to content

Reland [mlir] ArithToLLVM: fix memref bitcast lowering (#125148) #126939

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 12, 2025

Conversation

Hardcode84
Copy link
Contributor

Reland #125148

Limiting vector pattern caused issues with select of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to IdentityBitcastLowering so it will always run before the vector pattern.

`arith.bitcast` is allowed on memrefs and such code can actually be
generated by IREE `ConvertBf16ArithToF32Pass`.
`LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types
and will generate bitcast between structs which is illegal.

With the opaque pointers this is a no-op operation for memref so we can
just add a separate pattern which removes op if converted types are the same.
@llvmbot
Copy link
Member

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Reland #125148

Limiting vector pattern caused issues with select of complex lowering, which wasn't caught as it was missing lit tests. Keep the pattern as is for now and instead set a higher benefit to IdentityBitcastLowering so it will always run before the vector pattern.


Full diff: https://github.com/llvm/llvm-project/pull/126939.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+25)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+27-1)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed89814293..ced18a48766bf 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
+/// No-op bitcast. Propagate type input arg if converted source and dest types
+/// are the same.
+struct IdentityBitcastLowering final
+    : public OpConversionPattern<arith::BitcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    Value src = adaptor.getIn();
+    Type resultType = getTypeConverter()->convertType(op.getType());
+    if (src.getType() != resultType)
+      return rewriter.notifyMatchFailure(op, "Types are different");
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -524,6 +543,12 @@ void mlir::arith::registerConvertArithToLLVMInterface(
 
 void mlir::arith::populateArithToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+  // Set a higher pattern benefit for IdentityBitcastLowering so it will run
+  // before BitcastOpLowering.
+  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext(),
+                                        /*patternBenefit*/ 10);
+
   // clang-format off
   patterns.add<
     AddFOpLowering,
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47c..7daf4ef8717bc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -577,12 +577,26 @@ func.func @cmpi_2dvector(%arg0 : vector<4x3xi32>, %arg1 : vector<4x3xi32>) {
 // -----
 
 // CHECK-LABEL: @select
+//  CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
 func.func @select(%arg0 : i1, %arg1 : i32, %arg2 : i32) -> i32 {
-  // CHECK: = llvm.select %arg0, %arg1, %arg2 : i1, i32
+  // CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARG1]], %[[ARG2]] : i1, i32
+  // CHECK: return %[[RES]]
   %0 = arith.select %arg0, %arg1, %arg2 : i32
   return %0 : i32
 }
 
+// CHECK-LABEL: @select_complex
+//  CHECK-SAME:  (%[[ARG0:.*]]: i1, %[[ARG1:.*]]: complex<f32>, %[[ARG2:.*]]: complex<f32>)
+func.func @select_complex(%arg0 : i1, %arg1 : complex<f32>, %arg2 : complex<f32>) -> complex<f32> {
+  // CHECK-DAG: %[[ARGC1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : complex<f32> to !llvm.struct<(f32, f32)>
+  // CHECK-DAG: %[[ARGC2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : complex<f32> to !llvm.struct<(f32, f32)>
+  //     CHECK: %[[RES:.*]] = llvm.select %[[ARG0]], %[[ARGC1]], %[[ARGC2]] : i1, !llvm.struct<(f32, f32)>
+  //     CHECK: %[[RESC:.*]] = builtin.unrealized_conversion_cast %[[RES]] : !llvm.struct<(f32, f32)> to complex<f32>
+  //     CHECK: return %[[RESC]]
+  %0 = arith.select %arg0, %arg1, %arg2 : complex<f32>
+  return %0 : complex<f32>
+}
+
 // -----
 
 // CHECK-LABEL: @ceildivsi
@@ -727,3 +741,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi16>)
+//       CHECK:   %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       CHECK:   %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+//       CHECK:   return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+  %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+  func.return %2 : memref<?xbf16>
+}

@Hardcode84 Hardcode84 merged commit e167c31 into llvm:main Feb 12, 2025
10 checks passed
@Hardcode84 Hardcode84 deleted the reland-bitcast branch February 12, 2025 17:32
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
…lvm#126939)

Reland llvm#125148

Limiting vector pattern caused issues with `select` of complex lowering,
which wasn't caught as it was missing lit tests. Keep the pattern as is
for now and instead set a higher benefit to `IdentityBitcastLowering` so
it will always run before the vector pattern.
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
…lvm#126939)

Reland llvm#125148

Limiting vector pattern caused issues with `select` of complex lowering,
which wasn't caught as it was missing lit tests. Keep the pattern as is
for now and instead set a higher benefit to `IdentityBitcastLowering` so
it will always run before the vector pattern.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
…lvm#126939)

Reland llvm#125148

Limiting vector pattern caused issues with `select` of complex lowering,
which wasn't caught as it was missing lit tests. Keep the pattern as is
for now and instead set a higher benefit to `IdentityBitcastLowering` so
it will always run before the vector pattern.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants