Skip to content

Commit e0e615e

Browse files
authored
Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87504)" (#87631)
This reverts commit d6e4582 because it violates an assumption of Mem2Reg's block argument creation: Mem2Reg strongly assumes that all involved values have the same type as the alloca, an assumption that the reverted PR relaxed. As a result, branches were created that jumped to basic blocks with differently typed block arguments.
1 parent 8f9903d commit e0e615e

File tree

2 files changed

+4
-231
lines changed

2 files changed

+4
-231
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,6 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
126126
return getValue();
127127
}
128128

129-
/// Checks that two types are the same or can be cast into one another.
130-
static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
131-
return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
132-
!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
133-
layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
134-
}
135-
136129
bool LLVM::LoadOp::canUsesBeRemoved(
137130
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
138131
SmallVectorImpl<OpOperand *> &newBlockingUses,
@@ -145,39 +138,15 @@ bool LLVM::LoadOp::canUsesBeRemoved(
145138
// be removed (provided it loads the exact stored value and is not
146139
// volatile).
147140
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
148-
areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
149-
!getVolatile_();
150-
}
151-
152-
/// Constructs operations that convert `inputValue` into a new value of type
153-
/// `targetType`. Assumes that this conversion is possible.
154-
static Value createConversionSequence(RewriterBase &rewriter, Location loc,
155-
Value inputValue, Type targetType) {
156-
if (inputValue.getType() == targetType)
157-
return inputValue;
158-
159-
if (!isa<LLVM::LLVMPointerType>(targetType) &&
160-
!isa<LLVM::LLVMPointerType>(inputValue.getType()))
161-
return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
162-
163-
if (!isa<LLVM::LLVMPointerType>(targetType))
164-
return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
165-
166-
if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
167-
return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
168-
169-
return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
170-
inputValue);
141+
getResult().getType() == slot.elemType && !getVolatile_();
171142
}
172143

173144
DeletionKind LLVM::LoadOp::removeBlockingUses(
174145
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
175146
RewriterBase &rewriter, Value reachingDefinition) {
176147
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
177148
// pointer.
178-
Value newResult = createConversionSequence(
179-
rewriter, getLoc(), reachingDefinition, getResult().getType());
180-
rewriter.replaceAllUsesWith(getResult(), newResult);
149+
rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
181150
return DeletionKind::Delete;
182151
}
183152

@@ -192,7 +161,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
192161
// fine, provided we are currently promoting its target value. Don't allow a
193162
// store OF the slot pointer, only INTO the slot pointer.
194163
return blockingUse == slot.ptr && getAddr() == slot.ptr &&
195-
getValue() != slot.ptr && !getVolatile_();
164+
getValue() != slot.ptr && getValue().getType() == slot.elemType &&
165+
!getVolatile_();
196166
}
197167

198168
DeletionKind LLVM::StoreOp::removeBlockingUses(

mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 0 additions & 197 deletions
Original file line numberDiff line numberDiff line change
@@ -697,200 +697,3 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
697697
%3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
698698
llvm.return %3 : !llvm.ptr
699699
}
700-
701-
// -----
702-
703-
// CHECK-LABEL: @load_int_from_float
704-
llvm.func @load_int_from_float() -> i32 {
705-
%0 = llvm.mlir.constant(1 : i32) : i32
706-
// CHECK-NOT: llvm.alloca
707-
%1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
708-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
709-
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
710-
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
711-
// CHECK: llvm.return %[[BITCAST:.*]]
712-
llvm.return %2 : i32
713-
}
714-
715-
// -----
716-
717-
// CHECK-LABEL: @load_float_from_int
718-
llvm.func @load_float_from_int() -> f32 {
719-
%0 = llvm.mlir.constant(1 : i32) : i32
720-
// CHECK-NOT: llvm.alloca
721-
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
722-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
723-
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
724-
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
725-
// CHECK: llvm.return %[[BITCAST:.*]]
726-
llvm.return %2 : f32
727-
}
728-
729-
// -----
730-
731-
// CHECK-LABEL: @load_int_from_vector
732-
llvm.func @load_int_from_vector() -> i32 {
733-
%0 = llvm.mlir.constant(1 : i32) : i32
734-
// CHECK-NOT: llvm.alloca
735-
%1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
736-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
737-
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
738-
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
739-
// CHECK: llvm.return %[[BITCAST:.*]]
740-
llvm.return %2 : i32
741-
}
742-
743-
// -----
744-
745-
// LLVM arrays cannot be bitcasted, so the following cannot be promoted.
746-
747-
// CHECK-LABEL: @load_int_from_array
748-
llvm.func @load_int_from_array() -> i32 {
749-
%0 = llvm.mlir.constant(1 : i32) : i32
750-
// CHECK: llvm.alloca
751-
%1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
752-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
753-
// CHECK-NOT: llvm.bitcast
754-
llvm.return %2 : i32
755-
}
756-
757-
// -----
758-
759-
// CHECK-LABEL: @store_int_to_float
760-
// CHECK-SAME: %[[ARG:.*]]: i32
761-
llvm.func @store_int_to_float(%arg: i32) -> i32 {
762-
%0 = llvm.mlir.constant(1 : i32) : i32
763-
// CHECK-NOT: llvm.alloca
764-
%1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
765-
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
766-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
767-
// CHECK: llvm.return %[[ARG]]
768-
llvm.return %2 : i32
769-
}
770-
771-
// -----
772-
773-
// CHECK-LABEL: @store_float_to_int
774-
// CHECK-SAME: %[[ARG:.*]]: f32
775-
llvm.func @store_float_to_int(%arg: f32) -> i32 {
776-
%0 = llvm.mlir.constant(1 : i32) : i32
777-
// CHECK-NOT: llvm.alloca
778-
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
779-
llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
780-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
781-
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
782-
// CHECK: llvm.return %[[BITCAST]]
783-
llvm.return %2 : i32
784-
}
785-
786-
// -----
787-
788-
// CHECK-LABEL: @store_int_to_vector
789-
// CHECK-SAME: %[[ARG:.*]]: i32
790-
llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
791-
%0 = llvm.mlir.constant(1 : i32) : i32
792-
// CHECK-NOT: llvm.alloca
793-
%1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
794-
llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
795-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
796-
// CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
797-
// CHECK: llvm.return %[[BITCAST]]
798-
llvm.return %2 : vector<4xi8>
799-
}
800-
801-
// -----
802-
803-
// CHECK-LABEL: @load_ptr_from_int
804-
llvm.func @load_ptr_from_int() -> !llvm.ptr {
805-
%0 = llvm.mlir.constant(1 : i32) : i32
806-
// CHECK-NOT: llvm.alloca
807-
%1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
808-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
809-
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
810-
// CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
811-
// CHECK: llvm.return %[[CAST:.*]]
812-
llvm.return %2 : !llvm.ptr
813-
}
814-
815-
// -----
816-
817-
// CHECK-LABEL: @load_int_from_ptr
818-
llvm.func @load_int_from_ptr() -> i64 {
819-
%0 = llvm.mlir.constant(1 : i32) : i32
820-
// CHECK-NOT: llvm.alloca
821-
%1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
822-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
823-
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
824-
// CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
825-
// CHECK: llvm.return %[[CAST:.*]]
826-
llvm.return %2 : i64
827-
}
828-
829-
// -----
830-
831-
// CHECK-LABEL: @load_ptr_addrspace_cast
832-
llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
833-
%0 = llvm.mlir.constant(1 : i32) : i32
834-
// CHECK-NOT: llvm.alloca
835-
%1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
836-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
837-
// CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
838-
// CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
839-
// CHECK: llvm.return %[[CAST:.*]]
840-
llvm.return %2 : !llvm.ptr<2>
841-
}
842-
843-
// -----
844-
845-
// CHECK-LABEL: @load_smaller_int
846-
llvm.func @load_smaller_int() -> i16 {
847-
%0 = llvm.mlir.constant(1 : i32) : i32
848-
// CHECK: llvm.alloca
849-
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
850-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
851-
llvm.return %2 : i16
852-
}
853-
854-
// -----
855-
856-
// CHECK-LABEL: @load_different_type_smaller
857-
llvm.func @load_different_type_smaller() -> f32 {
858-
%0 = llvm.mlir.constant(1 : i32) : i32
859-
// CHECK: llvm.alloca
860-
%1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
861-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
862-
llvm.return %2 : f32
863-
}
864-
865-
// -----
866-
867-
// This alloca is too small for the load, still, mem2reg should not touch it.
868-
869-
// CHECK-LABEL: @impossible_load
870-
llvm.func @impossible_load() -> f64 {
871-
%0 = llvm.mlir.constant(1 : i32) : i32
872-
// CHECK: llvm.alloca
873-
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
874-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
875-
llvm.return %2 : f64
876-
}
877-
878-
// -----
879-
880-
// Verifies that mem2reg does not introduce address space casts of pointers
881-
// with different bitsize.
882-
883-
module attributes { dlti.dl_spec = #dlti.dl_spec<
884-
#dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
885-
#dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
886-
>} {
887-
888-
// CHECK-LABEL: @load_ptr_addrspace_cast_different_size
889-
llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
890-
%0 = llvm.mlir.constant(1 : i32) : i32
891-
// CHECK: llvm.alloca
892-
%1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
893-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
894-
llvm.return %2 : !llvm.ptr<2>
895-
}
896-
}

0 commit comments

Comments
 (0)