@@ -112,7 +112,8 @@ bool LLVM::LoadOp::loadsFrom(const MemorySlot &slot) {
112
112
113
113
bool LLVM::LoadOp::storesTo (const MemorySlot &slot) { return false ; }
114
114
115
- Value LLVM::LoadOp::getStored (const MemorySlot &slot, RewriterBase &rewriter) {
115
+ Value LLVM::LoadOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
116
+ const DataLayout &dataLayout) {
116
117
llvm_unreachable (" getStored should not be called on LoadOp" );
117
118
}
118
119
@@ -122,37 +123,124 @@ bool LLVM::StoreOp::storesTo(const MemorySlot &slot) {
122
123
return getAddr () == slot.ptr ;
123
124
}
124
125
125
- // / Checks that two types are the same or can be cast into one another.
126
- static bool areCastCompatible (const DataLayout &layout, Type lhs, Type rhs) {
127
- return lhs == rhs || (!isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(lhs) &&
128
- !isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(rhs) &&
129
- layout.getTypeSize (lhs) == layout.getTypeSize (rhs));
126
+ // / Checks if `type` can be used in any kind of conversion sequences.
127
+ static bool isSupportedTypeForConversion (Type type) {
128
+ // Aggregate types are not bitcastable.
129
+ if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type))
130
+ return false ;
131
+
132
+ // LLVM vector types are only used for either pointers or target specific
133
+ // types. These types cannot be casted in the general case, thus the memory
134
+ // optimizations do not support them.
135
+ if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type))
136
+ return false ;
137
+
138
+ // Scalable types are not supported.
139
+ if (auto vectorType = dyn_cast<VectorType>(type))
140
+ return !vectorType.isScalable ();
141
+ return true ;
142
+ }
143
+
144
+ // / Checks that `rhs` can be converted to `lhs` by a sequence of casts and
145
+ // / truncations.
146
+ static bool areConversionCompatible (const DataLayout &layout, Type lhs,
147
+ Type rhs) {
148
+ if (lhs == rhs)
149
+ return true ;
150
+
151
+ // Aggregate types cannot be casted.
152
+ if (!isSupportedTypeForConversion (lhs) || !isSupportedTypeForConversion (rhs))
153
+ return false ;
154
+ return layout.getTypeSize (lhs) <= layout.getTypeSize (rhs);
130
155
}
131
156
157
+ // / Checks if `dataLayout` describes a little endian layout.
158
+ static bool isLittleEndian (const DataLayout &dataLayout) {
159
+ auto endiannessStr = dyn_cast_or_null<StringAttr>(dataLayout.getEndianness ());
160
+ return !endiannessStr || endiannessStr == " little" ;
161
+ }
162
+
163
+ // / The size of a byte in bits.
164
+ constexpr const static uint64_t kBitsInByte = 8 ;
165
+
132
166
// / Constructs operations that convert `inputValue` into a new value of type
133
167
// / `targetType`. Assumes that this conversion is possible.
134
168
static Value createConversionSequence (RewriterBase &rewriter, Location loc,
135
- Value inputValue, Type targetType) {
136
- if (inputValue.getType () == targetType)
137
- return inputValue;
169
+ Value srcValue, Type targetType,
170
+ const DataLayout &dataLayout) {
171
+ // Get the types of the source and destination values.
172
+ Type srcType = srcValue.getType ();
173
+
174
+ uint64_t srcTypeSize = dataLayout.getTypeSize (srcType);
175
+ uint64_t targetTypeSize = dataLayout.getTypeSize (targetType);
176
+
177
+ // Nothing has to be done if the types are already the same.
178
+ if (srcType == targetType)
179
+ return srcValue;
180
+
181
+ // The code below is currently not capable of handling aggregate types as it
182
+ // makes use of bitcasts. Aggregates cannot be bitcast.
183
+ // TODO: We should have a `LLVMAggregateType` base class to easily perform
184
+ // this `isa`.
185
+ if (isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(srcType) ||
186
+ isa<LLVM::LLVMArrayType, LLVM::LLVMStructType>(targetType))
187
+ return nullptr ;
188
+
189
+ // In the special case of casting one pointer to another, we want to generate
190
+ // an address space cast. Bitcasts of pointers are not allowed and using
191
+ // pointer to integer conversions are not equivalent due to the loss or
192
+ // provenance.
193
+ if (isa<LLVM::LLVMPointerType>(targetType) &&
194
+ isa<LLVM::LLVMPointerType>(srcType)) {
195
+ // Abort the conversion if the pointers have different bitwidths.
196
+ if (srcTypeSize != targetTypeSize)
197
+ return nullptr ;
198
+ return rewriter.createOrFold <LLVM::AddrSpaceCastOp>(loc, targetType,
199
+ srcValue);
200
+ }
138
201
139
- if (!isa<LLVM::LLVMPointerType>(targetType) &&
140
- !isa<LLVM::LLVMPointerType>(inputValue.getType ()))
141
- return rewriter.createOrFold <LLVM::BitcastOp>(loc, targetType, inputValue);
202
+ IntegerType valueSizeInteger =
203
+ rewriter.getIntegerType (srcTypeSize * kBitsInByte );
204
+ Value replacement = srcValue;
205
+
206
+ // First, cast the value to a same-sized integer type.
207
+ if (isa<LLVM::LLVMPointerType>(srcType))
208
+ replacement = rewriter.createOrFold <LLVM::PtrToIntOp>(loc, valueSizeInteger,
209
+ replacement);
210
+ else if (replacement.getType () != valueSizeInteger)
211
+ replacement = rewriter.createOrFold <LLVM::BitcastOp>(loc, valueSizeInteger,
212
+ replacement);
213
+
214
+ // Truncate the integer if the size of the read is less than the value.
215
+ if (targetTypeSize != srcTypeSize) {
216
+ if (!isLittleEndian (dataLayout)) {
217
+ uint64_t shiftAmount = (srcTypeSize - targetTypeSize) * kBitsInByte ;
218
+ auto shiftConstant = rewriter.create <LLVM::ConstantOp>(
219
+ loc, rewriter.getIntegerAttr (srcType, shiftAmount));
220
+ replacement =
221
+ rewriter.createOrFold <LLVM::LShrOp>(loc, srcValue, shiftConstant);
222
+ }
142
223
143
- if (!isa<LLVM::LLVMPointerType>(targetType))
144
- return rewriter.createOrFold <LLVM::PtrToIntOp>(loc, targetType, inputValue);
224
+ replacement = rewriter.create <LLVM::TruncOp>(
225
+ loc, rewriter.getIntegerType (targetTypeSize * kBitsInByte ),
226
+ replacement);
227
+ }
145
228
146
- if (!isa<LLVM::LLVMPointerType>(inputValue.getType ()))
147
- return rewriter.createOrFold <LLVM::IntToPtrOp>(loc, targetType, inputValue);
229
+ // Now cast the integer to the actual destination type if required.
230
+ if (isa<LLVM::LLVMPointerType>(targetType))
231
+ replacement =
232
+ rewriter.createOrFold <LLVM::IntToPtrOp>(loc, targetType, replacement);
233
+ else if (replacement.getType () != targetType)
234
+ replacement =
235
+ rewriter.createOrFold <LLVM::BitcastOp>(loc, targetType, replacement);
148
236
149
- return rewriter.createOrFold <LLVM::AddrSpaceCastOp>(loc, targetType,
150
- inputValue);
237
+ return replacement;
151
238
}
152
239
153
- Value LLVM::StoreOp::getStored (const MemorySlot &slot, RewriterBase &rewriter) {
154
- return createConversionSequence (rewriter, getLoc (), getValue (),
155
- slot.elemType );
240
+ Value LLVM::StoreOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
241
+ const DataLayout &dataLayout) {
242
+ return createConversionSequence (rewriter, getLoc (), getValue (), slot.elemType ,
243
+ dataLayout);
156
244
}
157
245
158
246
bool LLVM::LoadOp::canUsesBeRemoved (
@@ -167,17 +255,20 @@ bool LLVM::LoadOp::canUsesBeRemoved(
167
255
// be removed (provided it loads the exact stored value and is not
168
256
// volatile).
169
257
return blockingUse == slot.ptr && getAddr () == slot.ptr &&
170
- areCastCompatible (dataLayout, getResult ().getType (), slot.elemType ) &&
258
+ areConversionCompatible (dataLayout, getResult ().getType (),
259
+ slot.elemType ) &&
171
260
!getVolatile_ ();
172
261
}
173
262
174
263
DeletionKind LLVM::LoadOp::removeBlockingUses (
175
264
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
176
- RewriterBase &rewriter, Value reachingDefinition) {
265
+ RewriterBase &rewriter, Value reachingDefinition,
266
+ const DataLayout &dataLayout) {
177
267
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
178
268
// pointer.
179
- Value newResult = createConversionSequence (
180
- rewriter, getLoc (), reachingDefinition, getResult ().getType ());
269
+ Value newResult =
270
+ createConversionSequence (rewriter, getLoc (), reachingDefinition,
271
+ getResult ().getType (), dataLayout);
181
272
rewriter.replaceAllUsesWith (getResult (), newResult);
182
273
return DeletionKind::Delete;
183
274
}
@@ -194,13 +285,15 @@ bool LLVM::StoreOp::canUsesBeRemoved(
194
285
// store OF the slot pointer, only INTO the slot pointer.
195
286
return blockingUse == slot.ptr && getAddr () == slot.ptr &&
196
287
getValue () != slot.ptr &&
197
- areCastCompatible (dataLayout, slot.elemType , getValue ().getType ()) &&
288
+ areConversionCompatible (dataLayout, slot.elemType ,
289
+ getValue ().getType ()) &&
198
290
!getVolatile_ ();
199
291
}
200
292
201
293
DeletionKind LLVM::StoreOp::removeBlockingUses (
202
294
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
203
- RewriterBase &rewriter, Value reachingDefinition) {
295
+ RewriterBase &rewriter, Value reachingDefinition,
296
+ const DataLayout &dataLayout) {
204
297
return DeletionKind::Delete;
205
298
}
206
299
@@ -747,8 +840,8 @@ bool LLVM::MemsetOp::storesTo(const MemorySlot &slot) {
747
840
return getDst () == slot.ptr ;
748
841
}
749
842
750
- Value LLVM::MemsetOp::getStored (const MemorySlot &slot,
751
- RewriterBase &rewriter ) {
843
+ Value LLVM::MemsetOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
844
+ const DataLayout &dataLayout ) {
752
845
// TODO: Support non-integer types.
753
846
return TypeSwitch<Type, Value>(slot.elemType )
754
847
.Case ([&](IntegerType intType) -> Value {
@@ -802,7 +895,8 @@ bool LLVM::MemsetOp::canUsesBeRemoved(
802
895
803
896
DeletionKind LLVM::MemsetOp::removeBlockingUses (
804
897
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
805
- RewriterBase &rewriter, Value reachingDefinition) {
898
+ RewriterBase &rewriter, Value reachingDefinition,
899
+ const DataLayout &dataLayout) {
806
900
return DeletionKind::Delete;
807
901
}
808
902
@@ -1059,8 +1153,8 @@ bool LLVM::MemcpyOp::storesTo(const MemorySlot &slot) {
1059
1153
return memcpyStoresTo (*this , slot);
1060
1154
}
1061
1155
1062
- Value LLVM::MemcpyOp::getStored (const MemorySlot &slot,
1063
- RewriterBase &rewriter ) {
1156
+ Value LLVM::MemcpyOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
1157
+ const DataLayout &dataLayout ) {
1064
1158
return memcpyGetStored (*this , slot, rewriter);
1065
1159
}
1066
1160
@@ -1074,7 +1168,8 @@ bool LLVM::MemcpyOp::canUsesBeRemoved(
1074
1168
1075
1169
DeletionKind LLVM::MemcpyOp::removeBlockingUses (
1076
1170
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1077
- RewriterBase &rewriter, Value reachingDefinition) {
1171
+ RewriterBase &rewriter, Value reachingDefinition,
1172
+ const DataLayout &dataLayout) {
1078
1173
return memcpyRemoveBlockingUses (*this , slot, blockingUses, rewriter,
1079
1174
reachingDefinition);
1080
1175
}
@@ -1109,7 +1204,8 @@ bool LLVM::MemcpyInlineOp::storesTo(const MemorySlot &slot) {
1109
1204
}
1110
1205
1111
1206
Value LLVM::MemcpyInlineOp::getStored (const MemorySlot &slot,
1112
- RewriterBase &rewriter) {
1207
+ RewriterBase &rewriter,
1208
+ const DataLayout &dataLayout) {
1113
1209
return memcpyGetStored (*this , slot, rewriter);
1114
1210
}
1115
1211
@@ -1123,7 +1219,8 @@ bool LLVM::MemcpyInlineOp::canUsesBeRemoved(
1123
1219
1124
1220
DeletionKind LLVM::MemcpyInlineOp::removeBlockingUses (
1125
1221
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1126
- RewriterBase &rewriter, Value reachingDefinition) {
1222
+ RewriterBase &rewriter, Value reachingDefinition,
1223
+ const DataLayout &dataLayout) {
1127
1224
return memcpyRemoveBlockingUses (*this , slot, blockingUses, rewriter,
1128
1225
reachingDefinition);
1129
1226
}
@@ -1159,8 +1256,8 @@ bool LLVM::MemmoveOp::storesTo(const MemorySlot &slot) {
1159
1256
return memcpyStoresTo (*this , slot);
1160
1257
}
1161
1258
1162
- Value LLVM::MemmoveOp::getStored (const MemorySlot &slot,
1163
- RewriterBase &rewriter ) {
1259
+ Value LLVM::MemmoveOp::getStored (const MemorySlot &slot, RewriterBase &rewriter,
1260
+ const DataLayout &dataLayout ) {
1164
1261
return memcpyGetStored (*this , slot, rewriter);
1165
1262
}
1166
1263
@@ -1174,7 +1271,8 @@ bool LLVM::MemmoveOp::canUsesBeRemoved(
1174
1271
1175
1272
DeletionKind LLVM::MemmoveOp::removeBlockingUses (
1176
1273
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
1177
- RewriterBase &rewriter, Value reachingDefinition) {
1274
+ RewriterBase &rewriter, Value reachingDefinition,
1275
+ const DataLayout &dataLayout) {
1178
1276
return memcpyRemoveBlockingUses (*this , slot, blockingUses, rewriter,
1179
1277
reachingDefinition);
1180
1278
}
0 commit comments