Skip to content

Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87504)" #87631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

Dinistro
Copy link
Contributor

@Dinistro Dinistro commented Apr 4, 2024

This reverts commit d6e4582 as it violates an assumption of Mem2Reg's block argument creation. Mem2Reg strongly assumes that all involved values have the same type as the alloca, which was relaxed by this PR. Therefore, branches got created that jumped to basic blocks with differently typed block arguments.

…d and store (#87504)"

This reverts commit d6e4582 as it
violates an assumption of Mem2Regs block argument creation.
@llvmbot
Copy link
Member

llvmbot commented Apr 4, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Christian Ulmann (Dinistro)

Changes

…d and store (#87504)"

This reverts commit d6e4582 as it violates an assumption of Mem2Regs block argument creation.


Full diff: https://github.com/llvm/llvm-project/pull/87631.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+4-34)
  • (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (-197)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index f8fd3bdaff4a73..06c1fdd2eb2d95 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -126,13 +126,6 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
   return getValue();
 }
 
-/// Checks that two types are the same or can be cast into one another.
-static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
-  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
-                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
-                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
-}
-
 bool LLVM::LoadOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses,
@@ -145,29 +138,7 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
-         !getVolatile_();
-}
-
-/// Constructs operations that convert `inputValue` into a new value of type
-/// `targetType`. Assumes that this conversion is possible.
-static Value createConversionSequence(RewriterBase &rewriter, Location loc,
-                                      Value inputValue, Type targetType) {
-  if (inputValue.getType() == targetType)
-    return inputValue;
-
-  if (!isa<LLVM::LLVMPointerType>(targetType) &&
-      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
-
-  if (!isa<LLVM::LLVMPointerType>(targetType))
-    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
-
-  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
-
-  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
-                                                      inputValue);
+         getResult().getType() == slot.elemType && !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
@@ -175,9 +146,7 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
     RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  Value newResult = createConversionSequence(
-      rewriter, getLoc(), reachingDefinition, getResult().getType());
-  rewriter.replaceAllUsesWith(getResult(), newResult);
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
   return DeletionKind::Delete;
 }
 
@@ -192,7 +161,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // fine, provided we are currently promoting its target value. Don't allow a
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getValue() != slot.ptr && !getVolatile_();
+         getValue() != slot.ptr && getValue().getType() == slot.elemType &&
+         !getVolatile_();
 }
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index d6d5e1bdc93c76..90e56c1166edfd 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -697,200 +697,3 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
   %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
   llvm.return %3 : !llvm.ptr
 }
-
-// -----
-
-// CHECK-LABEL: @load_int_from_float
-llvm.func @load_int_from_float() -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
-  // CHECK: llvm.return %[[BITCAST:.*]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @load_float_from_int
-llvm.func @load_float_from_int() -> f32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
-  // CHECK: llvm.return %[[BITCAST:.*]]
-  llvm.return %2 : f32
-}
-
-// -----
-
-// CHECK-LABEL: @load_int_from_vector
-llvm.func @load_int_from_vector() -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
-  // CHECK: llvm.return %[[BITCAST:.*]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// LLVM arrays cannot be bitcasted, so the following cannot be promoted.
-
-// CHECK-LABEL: @load_int_from_array
-llvm.func @load_int_from_array() -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK-NOT: llvm.bitcast
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @store_int_to_float
-// CHECK-SAME: %[[ARG:.*]]: i32
-llvm.func @store_int_to_float(%arg: i32) -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: llvm.return %[[ARG]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @store_float_to_int
-// CHECK-SAME: %[[ARG:.*]]: f32
-llvm.func @store_float_to_int(%arg: f32) -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
-  // CHECK: llvm.return %[[BITCAST]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @store_int_to_vector
-// CHECK-SAME: %[[ARG:.*]]: i32
-llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
-  // CHECK: llvm.return %[[BITCAST]]
-  llvm.return %2 : vector<4xi8>
-}
-
-// -----
-
-// CHECK-LABEL: @load_ptr_from_int
-llvm.func @load_ptr_from_int() -> !llvm.ptr {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
-  // CHECK: llvm.return %[[CAST:.*]]
-  llvm.return %2 : !llvm.ptr
-}
-
-// -----
-
-// CHECK-LABEL: @load_int_from_ptr
-llvm.func @load_int_from_ptr() -> i64 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
-  // CHECK: llvm.return %[[CAST:.*]]
-  llvm.return %2 : i64
-}
-
-// -----
-
-// CHECK-LABEL: @load_ptr_addrspace_cast
-llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
-  // CHECK: llvm.return %[[CAST:.*]]
-  llvm.return %2 : !llvm.ptr<2>
-}
-
-// -----
-
-// CHECK-LABEL: @load_smaller_int
-llvm.func @load_smaller_int() -> i16 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
-  llvm.return %2 : i16
-}
-
-// -----
-
-// CHECK-LABEL: @load_different_type_smaller
-llvm.func @load_different_type_smaller() -> f32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
-  llvm.return %2 : f32
-}
-
-// -----
-
-// This alloca is too small for the load, still, mem2reg should not touch it.
-
-// CHECK-LABEL: @impossible_load
-llvm.func @impossible_load() -> f64 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
-  llvm.return %2 : f64
-}
-
-// -----
-
-// Verifies that mem2reg does not introduce address space casts of pointers
-// with different bitsize.
-
-module attributes { dlti.dl_spec = #dlti.dl_spec<
-  #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
-  #dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
->} {
-
-  // CHECK-LABEL: @load_ptr_addrspace_cast_different_size
-  llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
-    %0 = llvm.mlir.constant(1 : i32) : i32
-    // CHECK: llvm.alloca
-    %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
-    %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
-    llvm.return %2 : !llvm.ptr<2>
-  }
-}

@joker-eph joker-eph changed the title Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for loa… Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87504) Apr 4, 2024
@Dinistro Dinistro changed the title Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87504) Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87504)" Apr 4, 2024
@joker-eph
Copy link
Collaborator

This reverts commit d6e4582 as it violates an assumption of Mem2Regs block argument creation.

Can you make the assumption explicit in the description?

Also are there tests broken? An assertions? Otherwise can we add a test showing the problem?

@Dinistro
Copy link
Contributor Author

Dinistro commented Apr 4, 2024

This reverts commit d6e4582 as it violates an assumption of Mem2Regs block argument creation.

Can you make the assumption explicit in the description?

Also are there tests broken? An assertions? Otherwise can we add a test showing the problem?

I'll updated the description and will work on an updated revision that fixes the broken cases instead of shipping a PR that extends the test set, if that's fine for you.

@joker-eph
Copy link
Collaborator

Thanks for elaborating on the description!

@Dinistro Dinistro merged commit e0e615e into main Apr 4, 2024
@Dinistro Dinistro deleted the users/dinistro/revert-mem2reg-inconsistent-type-support branch April 4, 2024 13:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants