Skip to content

Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87504)" #87631

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 4 additions & 34 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,6 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
return getValue();
}

/// Checks that two types are the same or can be cast into one another.
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
}

bool LLVM::LoadOp::canUsesBeRemoved(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
SmallVectorImpl<OpOperand *> &newBlockingUses,
Expand All @@ -145,39 +138,15 @@ bool LLVM::LoadOp::canUsesBeRemoved(
// be removed (provided it loads the exact stored value and is not
// volatile).
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
!getVolatile_();
}

/// Constructs operations that convert `inputValue` into a new value of type
/// `targetType`. Assumes that this conversion is possible.
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
Value inputValue, Type targetType) {
if (inputValue.getType() == targetType)
return inputValue;

if (!isa<LLVM::LLVMPointerType>(targetType) &&
!isa<LLVM::LLVMPointerType>(inputValue.getType()))
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);

if (!isa<LLVM::LLVMPointerType>(targetType))
return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);

if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);

return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
inputValue);
getResult().getType() == slot.elemType && !getVolatile_();
}

DeletionKind LLVM::LoadOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
// pointer.
Value newResult = createConversionSequence(
rewriter, getLoc(), reachingDefinition, getResult().getType());
rewriter.replaceAllUsesWith(getResult(), newResult);
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
return DeletionKind::Delete;
}

Expand All @@ -192,7 +161,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
// fine, provided we are currently promoting its target value. Don't allow a
// store OF the slot pointer, only INTO the slot pointer.
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
getValue() != slot.ptr && !getVolatile_();
getValue() != slot.ptr && getValue().getType() == slot.elemType &&
!getVolatile_();
}

DeletionKind LLVM::StoreOp::removeBlockingUses(
Expand Down
197 changes: 0 additions & 197 deletions mlir/test/Dialect/LLVMIR/mem2reg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -697,200 +697,3 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
%3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
llvm.return %3 : !llvm.ptr
}

// -----

// CHECK-LABEL: @load_int_from_float
llvm.func @load_int_from_float() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
// CHECK: llvm.return %[[BITCAST:.*]]
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @load_float_from_int
llvm.func @load_float_from_int() -> f32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
// CHECK: llvm.return %[[BITCAST:.*]]
llvm.return %2 : f32
}

// -----

// CHECK-LABEL: @load_int_from_vector
llvm.func @load_int_from_vector() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
// CHECK: llvm.return %[[BITCAST:.*]]
llvm.return %2 : i32
}

// -----

// LLVM arrays cannot be bitcasted, so the following cannot be promoted.

// CHECK-LABEL: @load_int_from_array
llvm.func @load_int_from_array() -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK-NOT: llvm.bitcast
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @store_int_to_float
// CHECK-SAME: %[[ARG:.*]]: i32
llvm.func @store_int_to_float(%arg: i32) -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: llvm.return %[[ARG]]
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @store_float_to_int
// CHECK-SAME: %[[ARG:.*]]: f32
llvm.func @store_float_to_int(%arg: f32) -> i32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
// CHECK: llvm.return %[[BITCAST]]
llvm.return %2 : i32
}

// -----

// CHECK-LABEL: @store_int_to_vector
// CHECK-SAME: %[[ARG:.*]]: i32
llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
// CHECK: llvm.return %[[BITCAST]]
llvm.return %2 : vector<4xi8>
}

// -----

// CHECK-LABEL: @load_ptr_from_int
llvm.func @load_ptr_from_int() -> !llvm.ptr {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
// CHECK: llvm.return %[[CAST:.*]]
llvm.return %2 : !llvm.ptr
}

// -----

// CHECK-LABEL: @load_int_from_ptr
llvm.func @load_int_from_ptr() -> i64 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
// CHECK: llvm.return %[[CAST:.*]]
llvm.return %2 : i64
}

// -----

// CHECK-LABEL: @load_ptr_addrspace_cast
llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK-NOT: llvm.alloca
%1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
// CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
// CHECK: llvm.return %[[CAST:.*]]
llvm.return %2 : !llvm.ptr<2>
}

// -----

// CHECK-LABEL: @load_smaller_int
llvm.func @load_smaller_int() -> i16 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
llvm.return %2 : i16
}

// -----

// CHECK-LABEL: @load_different_type_smaller
llvm.func @load_different_type_smaller() -> f32 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
llvm.return %2 : f32
}

// -----

// This alloca is too small for the load, still, mem2reg should not touch it.

// CHECK-LABEL: @impossible_load
llvm.func @impossible_load() -> f64 {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
llvm.return %2 : f64
}

// -----

// Verifies that mem2reg does not introduce address space casts of pointers
// with different bitsize.

module attributes { dlti.dl_spec = #dlti.dl_spec<
#dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
#dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
>} {

// CHECK-LABEL: @load_ptr_addrspace_cast_different_size
llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
%0 = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.alloca
%1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
llvm.return %2 : !llvm.ptr<2>
}
}