Skip to content

Revert "[mlir] ArithToLLVM: fix memref bitcast lowering" #126895

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 12, 2025

Conversation

Hardcode84
Copy link
Contributor

Reverts #125148

bot failures

@Hardcode84 Hardcode84 merged commit 0e779ad into main Feb 12, 2025
8 of 11 checks passed
@Hardcode84 Hardcode84 deleted the revert-125148-fix-memref-bitcast-llvm branch February 12, 2025 11:34
@llvmbot
Copy link
Member

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

Reverts llvm/llvm-project#125148

bot failures


Full diff: https://github.com/llvm/llvm-project/pull/126895.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (-22)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+1-9)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (-12)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 5c1afe8034c73..754ed89814293 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,25 +54,6 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
-/// No-op bitcast. Propagate type input arg if converted source and dest types
-/// are the same.
-struct IdentityBitcastLowering final
-    : public OpConversionPattern<arith::BitcastOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    Value src = adaptor.getIn();
-    Type resultType = getTypeConverter()->convertType(op.getType());
-    if (src.getType() != resultType)
-      return rewriter.notifyMatchFailure(op, "Types are different");
-
-    rewriter.replaceOp(op, src);
-    return success();
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -543,9 +524,6 @@ void mlir::arith::registerConvertArithToLLVMInterface(
 
 void mlir::arith::populateArithToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-
-  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
-
   // clang-format off
   patterns.add<
     AddFOpLowering,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index fe4781138fa29..bf3f31729c3da 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,14 +103,6 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
-static bool isVectorCompatibleType(Type type) {
-  // Limit `vectorOneToOneRewrite` to scalar and vector types (and to
-  // `LLVM::LLVMArrayType` which have a special handling).
-  return isa<LLVM::LLVMArrayType, LLVM::LLVMPointerType, VectorType,
-             IntegerType, FloatType>(type) &&
-         LLVM::isCompatibleType(type);
-}
-
 LogicalResult LLVM::detail::vectorOneToOneRewrite(
     Operation *op, StringRef targetOp, ValueRange operands,
     ArrayRef<NamedAttribute> targetAttrs,
@@ -119,7 +111,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
-  if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
+  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
     return failure();
 
   auto llvmNDVectorTy = operands[0].getType();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 9a6c4bca88f3b..1dabacfd8a47c 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -727,15 +727,3 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
-
-// -----
-
-// CHECK-LABEL: func @memref_bitcast
-//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi16>)
-//       CHECK:   %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-//       CHECK:   %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
-//       CHECK:   return %[[V2]]
-func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
-  %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
-  func.return %2 : memref<?xbf16>
-}

@llvmbot
Copy link
Member

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-mlir-llvm

Author: Ivan Butygin (Hardcode84)

Changes

Reverts llvm/llvm-project#125148

bot failures


Full diff: https://github.com/llvm/llvm-project/pull/126895.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (-22)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+1-9)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (-12)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 5c1afe8034c73..754ed89814293 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,25 +54,6 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
-/// No-op bitcast. Propagate type input arg if converted source and dest types
-/// are the same.
-struct IdentityBitcastLowering final
-    : public OpConversionPattern<arith::BitcastOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const final {
-    Value src = adaptor.getIn();
-    Type resultType = getTypeConverter()->convertType(op.getType());
-    if (src.getType() != resultType)
-      return rewriter.notifyMatchFailure(op, "Types are different");
-
-    rewriter.replaceOp(op, src);
-    return success();
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -543,9 +524,6 @@ void mlir::arith::registerConvertArithToLLVMInterface(
 
 void mlir::arith::populateArithToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-
-  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
-
   // clang-format off
   patterns.add<
     AddFOpLowering,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index fe4781138fa29..bf3f31729c3da 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,14 +103,6 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
-static bool isVectorCompatibleType(Type type) {
-  // Limit `vectorOneToOneRewrite` to scalar and vector types (and to
-  // `LLVM::LLVMArrayType` which have a special handling).
-  return isa<LLVM::LLVMArrayType, LLVM::LLVMPointerType, VectorType,
-             IntegerType, FloatType>(type) &&
-         LLVM::isCompatibleType(type);
-}
-
 LogicalResult LLVM::detail::vectorOneToOneRewrite(
     Operation *op, StringRef targetOp, ValueRange operands,
     ArrayRef<NamedAttribute> targetAttrs,
@@ -119,7 +111,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
-  if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
+  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
     return failure();
 
   auto llvmNDVectorTy = operands[0].getType();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 9a6c4bca88f3b..1dabacfd8a47c 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -727,15 +727,3 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
-
-// -----
-
-// CHECK-LABEL: func @memref_bitcast
-//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi16>)
-//       CHECK:   %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
-//       CHECK:   %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
-//       CHECK:   return %[[V2]]
-func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
-  %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
-  func.return %2 : memref<?xbf16>
-}

flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants