Skip to content

[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store #87637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 5, 2024

Conversation

Dinistro
Copy link
Contributor

@Dinistro Dinistro commented Apr 4, 2024

This commit relaxes Mem2Reg's type equality requirement for the LLVM dialect's load and store operations. For now, we only allow loads to be promoted if the reaching definition can be casted into a value of the target type.

For stores, the same conversion casting check is applied and we ensure that their result is properly casted to the type of the memory slot. This is necessary to satisfy assumptions of the general mem2reg pass, as it creates block arguments with the types of the memory slot.

This relands #87504

This commit relaxes Mem2Reg's type equality requirement for the LLVM
dialect's load and store operations. For now, we only allow loads to be
promoted if the reaching definition can be casted into a value of the
target type.

For stores, the same conversion casting check is applied and we ensure
that their result is properly casted to the type of the memory slot.
This is necessary to satisfy assumptions of the general mem2reg pass, as
it creates block arguments with the types of the memory slot.
@llvmbot
Copy link
Member

llvmbot commented Apr 4, 2024

@llvm/pr-subscribers-mlir-llvm

Author: Christian Ulmann (Dinistro)

Changes

This commit relaxes Mem2Reg's type equality requirement for the LLVM dialect's load and store operations. For now, we only allow loads to be promoted if the reaching definition can be casted into a value of the target type.

For stores, the same conversion casting check is applied and we ensure that their result is properly casted to the type of the memory slot. This is necessary to satisfy assumptions of the general mem2reg pass, as it creates block arguments with the types of the memory slot.

This relands #87504


Full diff: https://github.com/llvm/llvm-project/pull/87637.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp (+36-5)
  • (modified) mlir/test/Dialect/LLVMIR/mem2reg.mlir (+225)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 06c1fdd2eb2d95..ec4000662e871c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -122,8 +122,37 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
   return getAddr() == slot.ptr;
 }
 
+/// Checks that two types are the same or can be cast into one another.
+static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
+  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
+                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
+                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+}
+
+/// Constructs operations that convert `inputValue` into a new value of type
+/// `targetType`. Assumes that this conversion is possible.
+static Value createConversionSequence(RewriterBase &rewriter, Location loc,
+                                      Value inputValue, Type targetType) {
+  if (inputValue.getType() == targetType)
+    return inputValue;
+
+  if (!isa<LLVM::LLVMPointerType>(targetType) &&
+      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(targetType))
+    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+
+  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                      inputValue);
+}
+
 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
-  return getValue();
+  return createConversionSequence(rewriter, getLoc(), getValue(),
+                                  slot.elemType);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -138,7 +167,8 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getResult().getType() == slot.elemType && !getVolatile_();
+         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+         !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
@@ -146,7 +176,9 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
     RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+  Value newResult = createConversionSequence(
+      rewriter, getLoc(), reachingDefinition, getResult().getType());
+  rewriter.replaceAllUsesWith(getResult(), newResult);
   return DeletionKind::Delete;
 }
 
@@ -161,8 +193,7 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // fine, provided we are currently promoting its target value. Don't allow a
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getValue() != slot.ptr && getValue().getType() == slot.elemType &&
-         !getVolatile_();
+         getValue() != slot.ptr && !getVolatile_();
 }
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 90e56c1166edfd..883142a92f12e3 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -697,3 +697,228 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
   %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
   llvm.return %3 : !llvm.ptr
 }
+
+// -----
+
+// CHECK-LABEL: @load_int_from_float
+llvm.func @load_int_from_float() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @load_float_from_int
+llvm.func @load_float_from_int() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_vector
+llvm.func @load_int_from_vector() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// LLVM arrays cannot be bitcasted, so the following cannot be promoted.
+
+// CHECK-LABEL: @load_int_from_array
+llvm.func @load_int_from_array() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK-NOT: llvm.bitcast
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_float
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_float(%arg: i32) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[ARG]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_float_to_int
+// CHECK-SAME: %[[ARG:.*]]: f32
+llvm.func @store_float_to_int(%arg: f32) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_vector
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
+  // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
+  // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
+  // CHECK: llvm.return %[[BITCAST1]]
+  llvm.return %2 : vector<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_from_int
+llvm.func @load_ptr_from_int() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : !llvm.ptr
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_ptr
+llvm.func @load_int_from_ptr() -> i64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_addrspace_cast
+llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : !llvm.ptr<2>
+}
+
+// -----
+
+// CHECK-LABEL: @store_with_different_types
+// CHECK-SAME: %[[ARG0:.*]]: i64
+// CHECK-SAME: %[[ARG1:.*]]: f64
+llvm.func @store_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+  // CHECK: llvm.br ^[[BB3:.*]](%[[ARG0]]
+  llvm.br ^bb3
+^bb2:
+  llvm.store %arg1, %1 {alignment = 4 : i64} : f64, !llvm.ptr
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG1]] : f64 to i64
+  // CHECK: llvm.br ^[[BB3]](%[[BITCAST]]
+  llvm.br ^bb3
+// CHECK: ^[[BB3]](%[[BLOCK_ARG:.*]]: i64)
+^bb3:
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[BLOCK_ARG]] : i64 to f64
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %2 : f64
+}
+
+// -----
+
+// CHECK-LABEL: @load_smaller_int
+llvm.func @load_smaller_int() -> i16 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
+  llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @load_different_type_smaller
+llvm.func @load_different_type_smaller() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+// This alloca is too small for the load, still, mem2reg should not touch it.
+
+// CHECK-LABEL: @impossible_load
+llvm.func @impossible_load() -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  llvm.return %2 : f64
+}
+
+// -----
+
+// Verifies that mem2reg does not introduce address space casts of pointers
+// with different bitsize.
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
+  #dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
+>} {
+
+  // CHECK-LABEL: @load_ptr_addrspace_cast_different_size
+  llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: llvm.alloca
+    %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+    %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+    llvm.return %2 : !llvm.ptr<2>
+  }
+}

@@ -161,8 +193,7 @@ bool LLVM::StoreOp::canUsesBeRemoved(
// fine, provided we are currently promoting its target value. Don't allow a
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr && getValue().getType() == slot.elemType &&
!getVolatile_();
getValue() != slot.ptr && !getVolatile_();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there be a check for cast compatibility here as well? I assume the conversion in getStored has to succeed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch. Working with a additional changes in the same files is a bad idea...
At least this shows that we should add a test for this case to ensure that this is detected.

@Dinistro Dinistro requested a review from gysit April 5, 2024 06:07
Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks LGTM!

@Dinistro Dinistro merged commit 974f1ee into main Apr 5, 2024
@Dinistro Dinistro deleted the users/dinistro/mem2reg-inconsistent-type-support branch April 5, 2024 06:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants