
Commit 974f1ee

[MLIR][LLVM][Mem2Reg] Relax type equality requirement for load and store (#87637)
This commit relaxes Mem2Reg's type equality requirement for the LLVM dialect's load and store operations. For now, a load is only promoted if the reaching definition can be cast to the load's result type. Stores are subject to the same cast-compatibility check, and the value they store is cast to the type of the memory slot. This is necessary to satisfy an assumption of the general mem2reg pass, which creates block arguments with the types of the memory slot. This relands #87504.
1 parent 1c683eb commit 974f1ee
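
As a quick illustration (distilled from the @store_float_to_int test added below; the post-promotion IR is paraphrased, not taken verbatim from the pass output): a store of an f32 into an i32 slot no longer blocks promotion. Given

    llvm.func @store_float_to_int(%arg: f32) -> i32 {
      %0 = llvm.mlir.constant(1 : i32) : i32
      %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
      llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
      %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
      llvm.return %2 : i32
    }

mem2reg now eliminates the alloca, store, and load, and routes the reaching definition through a cast, roughly (dead constants elided):

    llvm.func @store_float_to_int(%arg: f32) -> i32 {
      %0 = llvm.bitcast %arg : f32 to i32
      llvm.return %0 : i32
    }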

File tree: 2 files changed, +284 -4 lines

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 37 additions & 4 deletions
@@ -122,8 +122,37 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
   return getAddr() == slot.ptr;
 }
 
+/// Checks that two types are the same or can be cast into one another.
+static bool areCastCompatible(const DataLayout &layout, Type lhs, Type rhs) {
+  return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
+                        !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
+                        layout.getTypeSize(lhs) == layout.getTypeSize(rhs));
+}
+
+/// Constructs operations that convert `inputValue` into a new value of type
+/// `targetType`. Assumes that this conversion is possible.
+static Value createConversionSequence(RewriterBase &rewriter, Location loc,
+                                      Value inputValue, Type targetType) {
+  if (inputValue.getType() == targetType)
+    return inputValue;
+
+  if (!isa<LLVM::LLVMPointerType>(targetType) &&
+      !isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::BitcastOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(targetType))
+    return rewriter.createOrFold<LLVM::PtrToIntOp>(loc, targetType, inputValue);
+
+  if (!isa<LLVM::LLVMPointerType>(inputValue.getType()))
+    return rewriter.createOrFold<LLVM::IntToPtrOp>(loc, targetType, inputValue);
+
+  return rewriter.createOrFold<LLVM::AddrSpaceCastOp>(loc, targetType,
+                                                      inputValue);
+}
+
 Value LLVM::StoreOp::getStored(const MemorySlot &slot, RewriterBase &rewriter) {
-  return getValue();
+  return createConversionSequence(rewriter, getLoc(), getValue(),
+                                  slot.elemType);
 }
 
 bool LLVM::LoadOp::canUsesBeRemoved(
@@ -138,15 +167,18 @@ bool LLVM::LoadOp::canUsesBeRemoved(
   // be removed (provided it loads the exact stored value and is not
   // volatile).
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getResult().getType() == slot.elemType && !getVolatile_();
+         areCastCompatible(dataLayout, getResult().getType(), slot.elemType) &&
+         !getVolatile_();
 }
 
 DeletionKind LLVM::LoadOp::removeBlockingUses(
     const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
     RewriterBase &rewriter, Value reachingDefinition) {
   // `canUsesBeRemoved` checked this blocking use must be the loaded slot
   // pointer.
-  rewriter.replaceAllUsesWith(getResult(), reachingDefinition);
+  Value newResult = createConversionSequence(
+      rewriter, getLoc(), reachingDefinition, getResult().getType());
+  rewriter.replaceAllUsesWith(getResult(), newResult);
   return DeletionKind::Delete;
 }
 
@@ -161,7 +193,8 @@ bool LLVM::StoreOp::canUsesBeRemoved(
   // fine, provided we are currently promoting its target value. Don't allow a
   // store OF the slot pointer, only INTO the slot pointer.
   return blockingUse == slot.ptr && getAddr() == slot.ptr &&
-         getValue() != slot.ptr && getValue().getType() == slot.elemType &&
+         getValue() != slot.ptr &&
+         areCastCompatible(dataLayout, slot.elemType, getValue().getType()) &&
          !getVolatile_();
 }

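For reference, the new createConversionSequence helper emits at most one cast per promoted value, choosing among four mutually exclusive cases. A minimal sketch in LLVM dialect IR (value names are illustrative, not part of the patch):

    %a = llvm.bitcast %v : f32 to i32                          // neither type is a pointer
    %b = llvm.ptrtoint %p : !llvm.ptr to i64                   // only the source is a pointer
    %c = llvm.inttoptr %i : i64 to !llvm.ptr                   // only the target is a pointer
    %d = llvm.addrspacecast %q : !llvm.ptr<1> to !llvm.ptr<2>  // both types are pointers

When the types already match, the input value is returned unchanged and no cast is emitted.
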
mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 247 additions & 0 deletions
@@ -697,3 +697,250 @@ llvm.func @transitive_reaching_def() -> !llvm.ptr {
   %3 = llvm.load %1 {alignment = 8 : i64} : !llvm.ptr -> !llvm.ptr
   llvm.return %3 : !llvm.ptr
 }
+
+// -----
+
+// CHECK-LABEL: @load_int_from_float
+llvm.func @load_int_from_float() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : f32 to i32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @load_float_from_int
+llvm.func @load_float_from_int() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : i32 to f32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : f32
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_vector
+llvm.func @load_int_from_vector() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x vector<2xi16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[UNDEF]] : vector<2xi16> to i32
+  // CHECK: llvm.return %[[BITCAST:.*]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// LLVM arrays cannot be bitcasted, so the following cannot be promoted.
+
+// CHECK-LABEL: @load_int_from_array
+llvm.func @load_int_from_array() -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.array<2 x i16> : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK-NOT: llvm.bitcast
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_float
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_float(%arg: i32) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x f32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: llvm.return %[[ARG]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_float_to_int
+// CHECK-SAME: %[[ARG:.*]]: f32
+llvm.func @store_float_to_int(%arg: f32) -> i32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG]] : f32 to i32
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %2 : i32
+}
+
+// -----
+
+// CHECK-LABEL: @store_int_to_vector
+// CHECK-SAME: %[[ARG:.*]]: i32
+llvm.func @store_int_to_vector(%arg: i32) -> vector<4xi8> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x vector<2xi16> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.store %arg, %1 {alignment = 4 : i64} : i32, !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> vector<4xi8>
+  // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %[[ARG]] : i32 to vector<2xi16>
+  // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %[[BITCAST0]] : vector<2xi16> to vector<4xi8>
+  // CHECK: llvm.return %[[BITCAST1]]
+  llvm.return %2 : vector<4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_from_int
+llvm.func @load_ptr_from_int() -> !llvm.ptr {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.inttoptr %[[UNDEF]] : i64 to !llvm.ptr
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : !llvm.ptr
+}
+
+// -----
+
+// CHECK-LABEL: @load_int_from_ptr
+llvm.func @load_int_from_ptr() -> i64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.ptr {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i64
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.ptrtoint %[[UNDEF]] : !llvm.ptr to i64
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : i64
+}
+
+// -----
+
+// CHECK-LABEL: @load_ptr_addrspace_cast
+llvm.func @load_ptr_addrspace_cast() -> !llvm.ptr<2> {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef
+  // CHECK: %[[CAST:.*]] = llvm.addrspacecast %[[UNDEF]] : !llvm.ptr<1> to !llvm.ptr<2>
+  // CHECK: llvm.return %[[CAST:.*]]
+  llvm.return %2 : !llvm.ptr<2>
+}
+
+// -----
+
+// CHECK-LABEL: @stores_with_different_types
+// CHECK-SAME: %[[ARG0:.*]]: i64
+// CHECK-SAME: %[[ARG1:.*]]: f64
+llvm.func @stores_with_different_types(%arg0: i64, %arg1: f64, %cond: i1) -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NOT: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+  // CHECK: llvm.br ^[[BB3:.*]](%[[ARG0]]
+  llvm.br ^bb3
+^bb2:
+  llvm.store %arg1, %1 {alignment = 4 : i64} : f64, !llvm.ptr
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[ARG1]] : f64 to i64
+  // CHECK: llvm.br ^[[BB3]](%[[BITCAST]]
+  llvm.br ^bb3
+// CHECK: ^[[BB3]](%[[BLOCK_ARG:.*]]: i64)
+^bb3:
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  // CHECK: %[[BITCAST:.*]] = llvm.bitcast %[[BLOCK_ARG]] : i64 to f64
+  // CHECK: llvm.return %[[BITCAST]]
+  llvm.return %2 : f64
+}
+
+// -----
+
+// Verifies that stores with smaller bitsize inputs are not replaced. A trivial
+// implementation will be incorrect due to endianness considerations.
+
+// CHECK-LABEL: @stores_with_different_type_sizes
+llvm.func @stores_with_different_type_sizes(%arg0: i64, %arg1: f32, %cond: i1) -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  llvm.cond_br %cond, ^bb1, ^bb2
+^bb1:
+  llvm.store %arg0, %1 {alignment = 4 : i64} : i64, !llvm.ptr
+  llvm.br ^bb3
+^bb2:
+  llvm.store %arg1, %1 {alignment = 4 : i64} : f32, !llvm.ptr
+  llvm.br ^bb3
+^bb3:
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  llvm.return %2 : f64
+}
+
+// -----
+
+// CHECK-LABEL: @load_smaller_int
+llvm.func @load_smaller_int() -> i16 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i16
+  llvm.return %2 : i16
+}
+
+// -----
+
+// CHECK-LABEL: @load_different_type_smaller
+llvm.func @load_different_type_smaller() -> f32 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i64 {alignment = 8 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f32
+  llvm.return %2 : f32
+}
+
+// -----
+
+// This alloca is too small for the load, still, mem2reg should not touch it.
+
+// CHECK-LABEL: @impossible_load
+llvm.func @impossible_load() -> f64 {
+  %0 = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: llvm.alloca
+  %1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
+  %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> f64
+  llvm.return %2 : f64
+}
+
+// -----
+
+// Verifies that mem2reg does not introduce address space casts of pointers
+// with different bitsize.
+
+module attributes { dlti.dl_spec = #dlti.dl_spec<
+  #dlti.dl_entry<!llvm.ptr<1>, dense<[32, 64, 64]> : vector<3xi64>>,
+  #dlti.dl_entry<!llvm.ptr<2>, dense<[64, 64, 64]> : vector<3xi64>>
+>} {
+
+  // CHECK-LABEL: @load_ptr_addrspace_cast_different_size
+  llvm.func @load_ptr_addrspace_cast_different_size() -> !llvm.ptr<2> {
+    %0 = llvm.mlir.constant(1 : i32) : i32
+    // CHECK: llvm.alloca
+    %1 = llvm.alloca %0 x !llvm.ptr<1> {alignment = 4 : i64} : (i32) -> !llvm.ptr
+    %2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> !llvm.ptr<2>
+    llvm.return %2 : !llvm.ptr<2>
+  }
+}
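
A note on two of the negative tests above, under one reading of the patch (the reasoning is hinted at, but not spelled out, in the test comments): @stores_with_different_type_sizes stays in memory because a 4-byte f32 store into an 8-byte i64 slot defines only the first four bytes at the slot address, and whether those bytes form the low or the high half of the i64 value depends on the target's endianness, so no single value-level cast of the stored f32 is correct on every target. Similarly, @load_ptr_addrspace_cast_different_size is rejected because areCastCompatible compares data layout type sizes, and an address space cast between 32-bit and 64-bit pointers would not be value-preserving.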
