@@ -121,6 +121,13 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
121
121
return std::nullopt;
122
122
}
123
123
124
+ static std::optional<APInt> extractConstantBits(const Constant *C,
125
+ unsigned NumBits) {
126
+ if (std::optional<APInt> Bits = extractConstantBits(C))
127
+ return Bits->zextOrTrunc(NumBits);
128
+ return std::nullopt;
129
+ }
130
+
124
131
// Attempt to compute the splat width of bits data by normalizing the splat to
125
132
// remove undefs.
126
133
static std::optional<APInt> getSplatableConstant(const Constant *C,
@@ -217,16 +224,15 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
217
224
218
225
// Attempt to rebuild a normalized splat vector constant of the requested splat
219
226
// width, built up of potentially smaller scalar values.
220
- static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts */,
221
- unsigned SplatBitWidth) {
227
+ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits */,
228
+ unsigned /*NumElts*/, unsigned SplatBitWidth) {
222
229
std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
223
230
if (!Splat)
224
231
return nullptr;
225
232
226
233
// Determine scalar size to use for the constant splat vector, clamping as we
227
234
// might have found a splat smaller than the original constant data.
228
- const Type *OriginalType = C->getType();
229
- Type *SclTy = OriginalType->getScalarType();
235
+ Type *SclTy = C->getType()->getScalarType();
230
236
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
231
237
NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
232
238
@@ -236,20 +242,19 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
236
242
: 64;
237
243
238
244
// Extract per-element bits.
239
- return rebuildConstant(OriginalType ->getContext(), SclTy, *Splat, NumSclBits);
245
+ return rebuildConstant(C ->getContext(), SclTy, *Splat, NumSclBits);
240
246
}
241
247
242
- static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
248
+ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
249
+ unsigned /*NumElts*/,
243
250
unsigned ScalarBitWidth) {
244
- Type *Ty = C->getType();
245
- Type *SclTy = Ty->getScalarType();
246
- unsigned NumBits = Ty->getPrimitiveSizeInBits();
251
+ Type *SclTy = C->getType()->getScalarType();
247
252
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
248
253
LLVMContext &Ctx = C->getContext();
249
254
250
255
if (NumBits > ScalarBitWidth) {
251
256
// Determine if the upper bits are all zero.
252
- if (std::optional<APInt> Bits = extractConstantBits(C)) {
257
+ if (std::optional<APInt> Bits = extractConstantBits(C, NumBits )) {
253
258
if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
254
259
// If the original constant was made of smaller elements, try to retain
255
260
// those types.
@@ -266,16 +271,15 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
266
271
return nullptr;
267
272
}
268
273
269
- static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
274
+ static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
275
+ unsigned NumBits, unsigned NumElts,
270
276
unsigned SrcEltBitWidth) {
271
- Type *Ty = C->getType();
272
- unsigned NumBits = Ty->getPrimitiveSizeInBits();
273
277
unsigned DstEltBitWidth = NumBits / NumElts;
274
278
assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
275
279
(DstEltBitWidth % SrcEltBitWidth) == 0 &&
276
280
(DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
277
281
278
- if (std::optional<APInt> Bits = extractConstantBits(C)) {
282
+ if (std::optional<APInt> Bits = extractConstantBits(C, NumBits )) {
279
283
assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
280
284
(Bits->getBitWidth() % DstEltBitWidth) == 0 &&
281
285
"Unexpected constant extension");
@@ -290,19 +294,20 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
290
294
TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
291
295
}
292
296
297
+ Type *Ty = C->getType();
293
298
return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
294
299
SrcEltBitWidth);
295
300
}
296
301
297
302
return nullptr;
298
303
}
299
- static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts ,
300
- unsigned SrcEltBitWidth) {
301
- return rebuildExtCst(C, true, NumElts, SrcEltBitWidth);
304
+ static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits ,
305
+ unsigned NumElts, unsigned SrcEltBitWidth) {
306
+ return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
302
307
}
303
- static Constant *rebuildZExtCst(const Constant *C, unsigned NumElts ,
304
- unsigned SrcEltBitWidth) {
305
- return rebuildExtCst(C, false, NumElts, SrcEltBitWidth);
308
+ static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits ,
309
+ unsigned NumElts, unsigned SrcEltBitWidth) {
310
+ return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
306
311
}
307
312
308
313
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
@@ -320,7 +325,7 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
320
325
int Op;
321
326
int NumCstElts;
322
327
int BitWidth;
323
- std::function<Constant *(const Constant *, unsigned, unsigned)>
328
+ std::function<Constant *(const Constant *, unsigned, unsigned, unsigned )>
324
329
RebuildConstant;
325
330
};
326
331
auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
@@ -335,12 +340,13 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
335
340
assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
336
341
"Unexpected number of operands!");
337
342
if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
343
+ unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
338
344
for (const FixupEntry &Fixup : Fixups) {
339
345
if (Fixup.Op) {
340
346
// Construct a suitable constant and adjust the MI to use the new
341
347
// constant pool entry.
342
- if (Constant *NewCst =
343
- Fixup.RebuildConstant(C , Fixup.NumCstElts, Fixup.BitWidth)) {
348
+ if (Constant *NewCst = Fixup.RebuildConstant(
349
+ C, NumBits , Fixup.NumCstElts, Fixup.BitWidth)) {
344
350
unsigned NewCPI =
345
351
CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
346
352
MI.setDesc(TII->get(Fixup.Op));
0 commit comments