
Commit 8a07ae4

Revert "[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87504)"
This reverts commit d6e4582, as it violates an assumption of Mem2Reg's block argument creation.
1 parent e69cab7 · commit 8a07ae4
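For context, the reverted change let Mem2Reg promote loads and stores whose type differs from the slot's element type (provided the type sizes match and neither type is a struct or array), inserting bitcast/ptrtoint/inttoptr/addrspacecast conversions as needed. With the revert, load and store types must again match the slot's element type exactly. A minimal illustration, taken from the test coverage removed in mem2reg.mlir below:

// Before the revert, Mem2Reg promoted this alloca and materialized an
// `llvm.mlir.undef : f32` plus an `llvm.bitcast` from f32 to i32 for the load
// (see the removed CHECK lines of @load_int_from_float below).
// After the revert, the load type (i32) no longer matches the slot's element
// type (f32), so the alloca is left untouched.
llvm.func @load_int_from_float() -> i32 {
  %0 = llvm.mlir.constant(1 : i32) : i32
  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
  llvm.return %2 : i32
}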

File tree: 2 files changed, +4 −231 lines

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 4 additions & 34 deletions
@@ -126,13 +126,6 @@ Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
   return getValue();
 }
 
-/// Checks that two types are the same or can be cast into one another.
-static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
-  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
-                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
-                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
-}
-
 bool LLVM::LoadOp::canUsesBeRemoved(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
     SmallVectorImpl<OpOperand *> &newBlockingUses,
@@ -145,39 +138,15 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
-         !getVolatile_();
-}
-
-/// Constructs operations that convert `inputValue` into a new value of type
-/// `targetType`. Assumes that this conversion is possible.
-static Value createConversionSequence(RewriterBase &rewriter, Location loc,
-                                      Value inputValue, Type targetType) {
-  if (inputValue.getType() == targetType)
-    return inputValue;
-
-  if (!isa<LLVM::LLVMPointerType>(targetType) &&
-      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
-
-  if (!isa<LLVM::LLVMPointerType>(targetType))
-    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
-
-  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
-    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
-
-  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
-                                                      inputValue);
+         getResult().getType() == slot.elemType && !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
     RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  Value newResult = createConversionSequence(
-      rewriter, getLoc(), reachingDefinition, getResult().getType());
-  rewriter.replaceAllUsesWith(getResult(), newResult);
+  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
   return DeletionKind::Delete;
 }
 
@@ -192,7 +161,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // fine, provided we are currently promoting its target value. Don't allow a
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getValue() != slot.ptr && !getVolatile_();
+         getValue() != slot.ptr && getValue().getType() == slot.elemType &&
+         !getVolatile_();
 }
 
 DeletionKind LLVM::StoreOp::removeBlockingUses(

mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 0 additions & 197 deletions
@@ -697,200 +697,3 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
   %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
   llvm.return %3 : !llvm.ptr
 }
-
-// -----
-
-// CHECK-LABEL: @load_int_from_float
-llvm.func @load_int_from_float() -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
-  // CHECK: llvm.return %[[BITCAST:.*]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @load_float_from_int
-llvm.func @load_float_from_int() -> f32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
-  // CHECK: llvm.return %[[BITCAST:.*]]
-  llvm.return %2 : f32
-}
-
-// -----
-
-// CHECK-LABEL: @load_int_from_vector
-llvm.func @load_int_from_vector() -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
-  // CHECK: llvm.return %[[BITCAST:.*]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// LLVM arrays cannot be bitcasted, so the following cannot be promoted.
-
-// CHECK-LABEL: @load_int_from_array
-llvm.func @load_int_from_array() -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK-NOT: llvm.bitcast
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @store_int_to_float
-// CHECK-SAME: %[[ARG:.*]]: i32
-llvm.func @store_int_to_float(%arg: i32) -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: llvm.return %[[ARG]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @store_float_to_int
-// CHECK-SAME: %[[ARG:.*]]: f32
-llvm.func @store_float_to_int(%arg: f32) -> i32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
-  // CHECK: llvm.return %[[BITCAST]]
-  llvm.return %2 : i32
-}
-
-// -----
-
-// CHECK-LABEL: @store_int_to_vector
-// CHECK-SAME: %[[ARG:.*]]: i32
-llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
-  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<4xi8>
-  // CHECK: llvm.return %[[BITCAST]]
-  llvm.return %2 : vector<4xi8>
-}
-
-// -----
-
-// CHECK-LABEL: @load_ptr_from_int
-llvm.func @load_ptr_from_int() -> !llvm.ptr {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
-  // CHECK: llvm.return %[[CAST:.*]]
-  llvm.return %2 : !llvm.ptr
-}
-
-// -----
-
-// CHECK-LABEL: @load_int_from_ptr
-llvm.func @load_int_from_ptr() -> i64 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
-  // CHECK: llvm.return %[[CAST:.*]]
-  llvm.return %2 : i64
-}
-
-// -----
-
-// CHECK-LABEL: @load_ptr_addrspace_cast
-llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK-NOT: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
-  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
-  // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
-  // CHECK: llvm.return %[[CAST:.*]]
-  llvm.return %2 : !llvm.ptr<2>
-}
-
-// -----
-
-// CHECK-LABEL: @load_smaller_int
-llvm.func @load_smaller_int() -> i16 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
-  llvm.return %2 : i16
-}
-
-// -----
-
-// CHECK-LABEL: @load_different_type_smaller
-llvm.func @load_different_type_smaller() -> f32 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
-  llvm.return %2 : f32
-}
-
-// -----
-
-// This alloca is too small for the load, still, mem2reg should not touch it.
-
-// CHECK-LABEL: @impossible_load
-llvm.func @impossible_load() -> f64 {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
-  llvm.return %2 : f64
-}
-
-// -----
-
-// Verifies that mem2reg does not introduce address space casts of pointers
-// with different bitsize.
-
-module attributes { dlti.dl_spec = #dlti.dl_spec<
-  #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
-  #dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
->} {
-
-// CHECK-LABEL: @load_ptr_addrspace_cast_different_size
-llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
-  %0 = llvm.mlir.constant(1 : i32) : i32
-  // CHECK: llvm.alloca
-  %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
-  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
-  llvm.return %2 : !llvm.ptr<2>
-}
-}
