@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
112
112
113
113
/// Always reports `false`: this op is never treated as storing to `slot`,
/// whatever slot is passed in.
bool LLVM::LoadOp::storesTo(const MemorySlot &slot) {
  return false;
}
114
114
115
- Value LLVM::LoadOp::getStored (const MemorySlot &slot, RewriterBase &rewriter) {
115
+ Value LLVM::LoadOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
116
+ const DataLayout &dataLayout) {
116
117
llvm_unreachable (" getStored should not be called on LoadOp" );
117
118
}
118
119
@@ -122,37 +123,121 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
122
123
return getAddr () == slot.ptr ;
123
124
}
124
125
125
- // / Checks that two types are the same or can be cast into one another.
126
- static bool areCastCompatible (const DataLayout &layout, Type lhs, Type rhs) {
127
- return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
128
- !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
129
- layout.getTypeSize (lhs) == layout.getTypeSize (rhs));
126
+ // / Checks if `type` can be used in any kind of conversion sequences.
127
+ static bool isSupportedTypeForConversion (Type type) {
128
+ // Aggregate types are not bitcastable.
129
+ if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
130
+ return false ;
131
+
132
+ // LLVM vector types are only used for either pointers or target specific
133
+ // types. These types cannot be casted in the general case, thus the memory
134
+ // optimizations do not support them.
135
+ if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
136
+ return false ;
137
+
138
+ // Scalable types are not supported.
139
+ if (auto vectorType = dyn_cast<VectorType>(type))
140
+ return !vectorType.isScalable ();
141
+ return true ;
130
142
}
131
143
144
+ // / Checks that `rhs` can be converted to `lhs` by a sequence of casts and
145
+ // / truncations.
146
+ static bool areConversionCompatible (const DataLayout &layout, Type targetType,
147
+ Type srcType) {
148
+ if (targetType == srcType)
149
+ return true ;
150
+
151
+ if (!isSupportedTypeForConversion (targetType) ||
152
+ !isSupportedTypeForConversion (srcType))
153
+ return false ;
154
+
155
+ // Pointer casts will only be sane when the bitsize of both pointer types is
156
+ // the same.
157
+ if (isa<LLVM::LLVMPointerType>(targetType) &&
158
+ isa<LLVM::LLVMPointerType>(srcType))
159
+ return layout.getTypeSize (targetType) == layout.getTypeSize (srcType);
160
+
161
+ return layout.getTypeSize (targetType) <= layout.getTypeSize (srcType);
162
+ }
163
+
164
+ // / Checks if `dataLayout` describes a little endian layout.
165
+ static bool isBigEndian (const DataLayout &dataLayout) {
166
+ auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness ());
167
+ return endiannessStr && endiannessStr == " big" ;
168
+ }
169
+
170
/// Number of bits in one byte; used to turn byte sizes into bit widths.
static constexpr uint64_t kBitsInByte = 8;
172
+
132
173
// / Constructs operations that convert `inputValue` into a new value of type
133
174
// / `targetType`. Assumes that this conversion is possible.
134
175
static Value createConversionSequence (RewriterBase &rewriter, Location loc,
135
- Value inputValue, Type targetType) {
136
- if (inputValue.getType () == targetType)
137
- return inputValue;
138
-
139
- if (!isa<LLVM::LLVMPointerType>(targetType) &&
140
- !isa<LLVM::LLVMPointerType>(inputValue.getType ()))
141
- return rewriter.createOrFold <LLVM::BitcastOp>(loc, targetType, inputValue);
176
+ Value srcValue, Type targetType,
177
+ const DataLayout &dataLayout) {
178
+ // Get the types of the source and target values.
179
+ Type srcType = srcValue.getType ();
180
+ assert (areConversionCompatible (dataLayout, targetType, srcType) &&
181
+ " expected that the compatibility was checked before" );
182
+
183
+ uint64_t srcTypeSize = dataLayout.getTypeSize (srcType);
184
+ uint64_t targetTypeSize = dataLayout.getTypeSize (targetType);
185
+
186
+ // Nothing has to be done if the types are already the same.
187
+ if (srcType == targetType)
188
+ return srcValue;
189
+
190
+ // In the special case of casting one pointer to another, we want to generate
191
+ // an address space cast. Bitcasts of pointers are not allowed and using
192
+ // pointer to integer conversions are not equivalent due to the loss of
193
+ // provenance.
194
+ if (isa<LLVM::LLVMPointerType>(targetType) &&
195
+ isa<LLVM::LLVMPointerType>(srcType))
196
+ return rewriter.createOrFold <LLVM::AddrSpaceCastOp>(loc, targetType,
197
+ srcValue);
198
+
199
+ IntegerType valueSizeInteger =
200
+ rewriter.getIntegerType (srcTypeSize * kBitsInByte );
201
+ Value replacement = srcValue;
202
+
203
+ // First, cast the value to a same-sized integer type.
204
+ if (isa<LLVM::LLVMPointerType>(srcType))
205
+ replacement = rewriter.createOrFold <LLVM::PtrToIntOp>(loc, valueSizeInteger,
206
+ replacement);
207
+ else if (replacement.getType () != valueSizeInteger)
208
+ replacement = rewriter.createOrFold <LLVM::BitcastOp>(loc, valueSizeInteger,
209
+ replacement);
210
+
211
+ // Truncate the integer if the size of the target is less than the value.
212
+ if (targetTypeSize != srcTypeSize) {
213
+ if (isBigEndian (dataLayout)) {
214
+ uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte ;
215
+ auto shiftConstant = rewriter.create <LLVM::ConstantOp>(
216
+ loc, rewriter.getIntegerAttr (srcType, shiftAmount));
217
+ replacement =
218
+ rewriter.createOrFold <LLVM::LShrOp>(loc, srcValue, shiftConstant);
219
+ }
142
220
143
- if (!isa<LLVM::LLVMPointerType>(targetType))
144
- return rewriter.createOrFold <LLVM::PtrToIntOp>(loc, targetType, inputValue);
221
+ replacement = rewriter.create <LLVM::TruncOp>(
222
+ loc, rewriter.getIntegerType (targetTypeSize * kBitsInByte ),
223
+ replacement);
224
+ }
145
225
146
- if (!isa<LLVM::LLVMPointerType>(inputValue.getType ()))
147
- return rewriter.createOrFold <LLVM::IntToPtrOp>(loc, targetType, inputValue);
226
+ // Now cast the integer to the actual target type if required.
227
+ if (isa<LLVM::LLVMPointerType>(targetType))
228
+ replacement =
229
+ rewriter.createOrFold <LLVM::IntToPtrOp>(loc, targetType, replacement);
230
+ else if (replacement.getType () != targetType)
231
+ replacement =
232
+ rewriter.createOrFold <LLVM::BitcastOp>(loc, targetType, replacement);
148
233
149
- return rewriter.createOrFold <LLVM::AddrSpaceCastOp>(loc, targetType,
150
- inputValue);
234
+ return replacement;
151
235
}
152
236
153
- Value LLVM::StoreOp::getStored (const MemorySlot &slot, RewriterBase &rewriter) {
154
- return createConversionSequence (rewriter, getLoc (), getValue (),
155
- slot.elemType );
237
+ Value LLVM::StoreOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
238
+ const DataLayout &dataLayout) {
239
+ return createConversionSequence (rewriter, getLoc (), getValue (), slot.elemType ,
240
+ dataLayout);
156
241
}
157
242
158
243
bool LLVM::LoadOp::canUsesBeRemoved (
@@ -167,17 +252,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
167
252
// be removed (provided it loads the exact stored value and is not
168
253
// volatile).
169
254
return blockingUse == slot.ptr && getAddr () == slot.ptr &&
170
- areCastCompatible (dataLayout, getResult ().getType (), slot.elemType ) &&
255
+ areConversionCompatible (dataLayout, getResult ().getType (),
256
+ slot.elemType ) &&
171
257
!getVolatile_ ();
172
258
}
173
259
174
260
/// Replaces every use of the loaded result with `reachingDefinition`,
/// converted to the result type of this load, and asks for this op to be
/// deleted. `slot` and `blockingUses` are not inspected here.
DeletionKind LLVM::LoadOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  // `canUsesBeRemoved` checked this blocking use must be the loaded slot
  // pointer.
  Value loaded = getResult();
  Value converted = createConversionSequence(
      rewriter, getLoc(), reachingDefinition, loaded.getType(), dataLayout);
  rewriter.replaceAllUsesWith(loaded, converted);
  return DeletionKind::Delete;
}
@@ -194,13 +282,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
194
282
// store OF the slot pointer, only INTO the slot pointer.
195
283
return blockingUse == slot.ptr && getAddr () == slot.ptr &&
196
284
getValue () != slot.ptr &&
197
- areCastCompatible (dataLayout, slot.elemType , getValue ().getType ()) &&
285
+ areConversionCompatible (dataLayout, slot.elemType ,
286
+ getValue ().getType ()) &&
198
287
!getVolatile_ ();
199
288
}
200
289
201
290
/// Unconditionally asks for this store to be deleted. No rewriting is done
/// and none of the arguments are inspected.
DeletionKind LLVM::StoreOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
206
296
@@ -747,8 +837,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
747
837
return getDst () == slot.ptr ;
748
838
}
749
839
750
- Value LLVM::MemsetOp::getStored (const MemorySlot &slot,
751
- RewriterBase &rewriter ) {
840
+ Value LLVM::MemsetOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
841
+ const DataLayout &dataLayout ) {
752
842
// TODO: Support non-integer types.
753
843
return TypeSwitch<Type, Value>(slot.elemType )
754
844
.Case ([&](IntegerType intType) -> Value {
@@ -802,7 +892,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
802
892
803
893
/// Unconditionally asks for this memset to be deleted. No rewriting is done
/// and none of the arguments are inspected.
DeletionKind LLVM::MemsetOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  return DeletionKind::Delete;
}
808
899
@@ -1059,8 +1150,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
1059
1150
return memcpyStoresTo (*this , slot);
1060
1151
}
1061
1152
1062
- Value LLVM::MemcpyOp::getStored (const MemorySlot &slot,
1063
- RewriterBase &rewriter ) {
1153
+ Value LLVM::MemcpyOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
1154
+ const DataLayout &dataLayout ) {
1064
1155
return memcpyGetStored (*this , slot, rewriter);
1065
1156
}
1066
1157
@@ -1074,7 +1165,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
1074
1165
1075
1166
/// Delegates to the shared `memcpyRemoveBlockingUses` helper and returns its
/// verdict. `dataLayout` is accepted for interface compatibility and is not
/// forwarded.
DeletionKind LLVM::MemcpyOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  const DeletionKind verdict = memcpyRemoveBlockingUses(
      *this, slot, blockingUses, rewriter, reachingDefinition);
  return verdict;
}
@@ -1109,7 +1201,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
1109
1201
}
1110
1202
1111
1203
/// Returns the value this inline memcpy stores into `slot`, as computed by the
/// shared `memcpyGetStored` helper. `dataLayout` is accepted for interface
/// compatibility and is not forwarded.
Value LLVM::MemcpyInlineOp::getStored(const MemorySlot &slot,
                                      RewriterBase &rewriter,
                                      const DataLayout &dataLayout) {
  Value stored = memcpyGetStored(*this, slot, rewriter);
  return stored;
}
1115
1208
@@ -1123,7 +1216,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
1123
1216
1124
1217
/// Delegates to the shared `memcpyRemoveBlockingUses` helper and returns its
/// verdict. `dataLayout` is accepted for interface compatibility and is not
/// forwarded.
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  const DeletionKind verdict = memcpyRemoveBlockingUses(
      *this, slot, blockingUses, rewriter, reachingDefinition);
  return verdict;
}
@@ -1159,8 +1253,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
1159
1253
return memcpyStoresTo (*this , slot);
1160
1254
}
1161
1255
1162
- Value LLVM::MemmoveOp::getStored (const MemorySlot &slot,
1163
- RewriterBase &rewriter ) {
1256
+ Value LLVM::MemmoveOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
1257
+ const DataLayout &dataLayout ) {
1164
1258
return memcpyGetStored (*this , slot, rewriter);
1165
1259
}
1166
1260
@@ -1174,7 +1268,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
1174
1268
1175
1269
/// Delegates to the shared `memcpyRemoveBlockingUses` helper and returns its
/// verdict. `dataLayout` is accepted for interface compatibility and is not
/// forwarded.
DeletionKind LLVM::MemmoveOp::removeBlockingUses(
    const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
    RewriterBase &rewriter, Value reachingDefinition,
    const DataLayout &dataLayout) {
  const DeletionKind verdict = memcpyRemoveBlockingUses(
      *this, slot, blockingUses, rewriter, reachingDefinition);
  return verdict;
}
0 commit comments