-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store #87504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store #87504
Conversation
This commit relaxes Mem2Reg's type equality requirement for the LLVM dialect's load and store operations. For now, we only allow loads to be promoted if the reaching definition can be casted into a value of the target type. For stores, all type checks are removed, as a non-volatile store that does not write out the alloca's pointer can always be deleted.
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Christian Ulmann (Dinistro) ChangesThis commit relaxes Mem2Reg's type equality requirement for the LLVM dialect's load and store operations. For now, we only allow loads to be promoted if the reaching definition can be casted into a value of the target type. For stores, all type checks are removed, as a non-volatile store that does not write out the alloca's pointer can always be deleted. Full diff: https://github.com/llvm/llvm-project/pull/87504.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 06c1fdd2eb2d95..f8fd3bdaff4a73 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -126,6 +126,13 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
return getValue();
}
+/// Checks that two types are the same or can be cast into one another.
+static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
+ return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
+ !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
+ layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+}
+
bool LLVM::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
@@ -138,7 +145,29 @@ bool LLVM::LoadOp::canUsesBeRemoved(
// be removed (provided it loads the exact stored value and is not
// volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
- getResult().getType() == slot.elemType && !getVolatile_();
+ areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+ !getVolatile_();
+}
+
+/// Constructs operations that convert `inputValue` into a new value of type
+/// `targetType`. Assumes that this conversion is possible.
+static Value createConversionSequence(RewriterBase &rewriter, Location loc,
+ Value inputValue, Type targetType) {
+ if (inputValue.getType() == targetType)
+ return inputValue;
+
+ if (!isa<LLVM::LLVMPointerType>(targetType) &&
+ !isa<LLVM::LLVMPointerType>(inputValue.getType()))
+ return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+
+ if (!isa<LLVM::LLVMPointerType>(targetType))
+ return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+
+ if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
+ return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+
+ return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+ inputValue);
}
DeletionKind LLVM::LoadOp::removeBlockingUses(
@@ -146,7 +175,9 @@ DeletionKind LLVM::LoadOp::removeBlockingUses(
RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
- rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+ Value newResult = createConversionSequence(
+ rewriter, getLoc(), reachingDefinition, getResult().getType());
+ rewriter.replaceAllUsesWith(getResult(), newResult);
return DeletionKind::Delete;
}
@@ -161,8 +192,7 @@ bool LLVM::StoreOp::canUsesBeRemoved(
// fine, provided we are currently promoting its target value. Don't allow a
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
- getValue() != slot.ptr && getValue().getType() == slot.elemType &&
- !getVolatile_();
+ getValue() != slot.ptr && !getVolatile_();
}
DeletionKind LLVM::StoreOp::removeBlockingUses(
diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
index 90e56c1166edfd..c2d443ef315379 100644
--- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir
+++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir
@@ -697,3 +697,203 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
%3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
llvm.return %3 : !llvm.ptr
}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_float
+llvm.func @load_int_from_float() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+ // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+ // CHECK: llvm.return %[[BITCAST:.*]]
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @load_float_from_int
+llvm.func @load_float_from_int() -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+ // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
+ // CHECK: llvm.return %[[BITCAST:.*]]
+ llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_vector
+llvm.func @load_int_from_vector() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+ // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
+ // CHECK: llvm.return %[[BITCAST:.*]]
+ llvm.return %2 : i32
+}
+
+// -----
+
+// LLVM arrays cannot be bitcasted, so the following cannot be promoted.
+
+// CHECK-LABEL: @load_int_from_array
+llvm.func @load_int_from_array() -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ // CHECK-NOT: llvm.bitcast
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_float
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_float(%arg: i32) -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ // CHECK: llvm.return %[[ARG]]
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_float_to_int
+// CHECK-SAME: %[[ARG:.*]]: f32
+llvm.func @store_float_to_int(%arg: f32) -> i32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+ // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
+ // CHECK: llvm.return %[[BITCAST]]
+ llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_vector
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
+ // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
+ // CHECK: llvm.return %[[BITCAST]]
+ llvm.return %2 : vector<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_from_int
+llvm.func @load_ptr_from_int() -> !llvm.ptr {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+ // CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
+ // CHECK: llvm.return %[[CAST:.*]]
+ llvm.return %2 : !llvm.ptr
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_ptr
+llvm.func @load_int_from_ptr() -> i64 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+ // CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
+ // CHECK: llvm.return %[[CAST:.*]]
+ llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_addrspace_cast
+llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK-NOT: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+ // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
+ // CHECK: llvm.return %[[CAST:.*]]
+ llvm.return %2 : !llvm.ptr<2>
+}
+
+// -----
+
+// CHECK-LABEL: @load_smaller_int
+llvm.func @load_smaller_int() -> i16 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
+ llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @load_different_type_smaller
+llvm.func @load_different_type_smaller() -> f32 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+ llvm.return %2 : f32
+}
+
+// -----
+
+// This alloca is too small for the load, still, mem2reg should not touch it.
+
+// CHECK-LABEL: @impossible_load
+llvm.func @impossible_load() -> f64 {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+ llvm.return %2 : f64
+}
+
+// -----
+
+// Verifies that mem2reg does not introduce address space casts of pointers
+// with different bitsize.
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+ #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
+ #dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
+>} {
+
+ // CHECK-LABEL: @load_ptr_addrspace_cast_different_size
+ llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
+ %0 = llvm.mlir.constant(1 : i32) : i32
+ // CHECK: llvm.alloca
+ %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+ %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+ // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+ // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
+ // CHECK: llvm.return %[[CAST:.*]]
+ llvm.return %2 : !llvm.ptr<2>
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo the last test that looks suspicious.
%1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr | ||
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2> | ||
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef | ||
// CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not really match the comment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whups, no idea how this sneaked in. The test actually just fails on this^^
…d and store (#87504)" (#87631) This reverts commit d6e4582 as it violates an assumption of Mem2Reg's block argument creation. Mem2Reg strongly assumes that all involved values have the same type as the alloca, which was relaxed by this PR. Therefore, branches got created that jumped to basic blocks with differently typed block arguments.
…ore (#87637) This commit relaxes Mem2Reg's type equality requirement for the LLVM dialect's load and store operations. For now, we only allow loads to be promoted if the reaching definition can be casted into a value of the target type. For stores, the same conversion casting check is applied and we ensure that their result is properly casted to the type of the memory slot. This is necessary to satisfy assumptions of the general mem2reg pass, as it creates block arguments with the types of the memory slot. This relands #87504
This commit relaxes Mem2Reg's type equality requirement for the LLVM dialect's load and store operations. For now, we only allow loads to be promoted if the reaching definition can be casted into a value of the target type.
For stores, all type checks are removed, as a non-volatile store that does not write out the alloca's pointer can always be deleted.