Skip to content

Commit b846613

Browse files
committed
[X86] X86FixupVectorConstants - add destination register width to rebuildSplatCst/rebuildZeroUpperCst/rebuildExtCst callbacks
As found on #81136 - we aren't correctly handling for cases where the constant pool entry is wider than the destination register width, causing incorrect scaling of the truncated constant for load-extension cases. This first patch just pulls out the destination register width argument, its still currently driven by the constant pool entry but that will be addressed in a followup.
1 parent 7d19dc5 commit b846613

File tree

1 file changed

+29
-23
lines changed

1 file changed

+29
-23
lines changed

llvm/lib/Target/X86/X86FixupVectorConstants.cpp

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,13 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
121121
return std::nullopt;
122122
}
123123

124+
static std::optional<APInt> extractConstantBits(const Constant *C,
125+
unsigned NumBits) {
126+
if (std::optional<APInt> Bits = extractConstantBits(C))
127+
return Bits->zextOrTrunc(NumBits);
128+
return std::nullopt;
129+
}
130+
124131
// Attempt to compute the splat width of bits data by normalizing the splat to
125132
// remove undefs.
126133
static std::optional<APInt> getSplatableConstant(const Constant *C,
@@ -217,16 +224,15 @@ static Constant *rebuildConstant(LLVMContext &Ctx, Type *SclTy,
217224

218225
// Attempt to rebuild a normalized splat vector constant of the requested splat
219226
// width, built up of potentially smaller scalar values.
220-
static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
221-
unsigned SplatBitWidth) {
227+
static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumBits*/,
228+
unsigned /*NumElts*/, unsigned SplatBitWidth) {
222229
std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
223230
if (!Splat)
224231
return nullptr;
225232

226233
// Determine scalar size to use for the constant splat vector, clamping as we
227234
// might have found a splat smaller than the original constant data.
228-
const Type *OriginalType = C->getType();
229-
Type *SclTy = OriginalType->getScalarType();
235+
Type *SclTy = C->getType()->getScalarType();
230236
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
231237
NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);
232238

@@ -236,20 +242,19 @@ static Constant *rebuildSplatCst(const Constant *C, unsigned /*NumElts*/,
236242
: 64;
237243

238244
// Extract per-element bits.
239-
return rebuildConstant(OriginalType->getContext(), SclTy, *Splat, NumSclBits);
245+
return rebuildConstant(C->getContext(), SclTy, *Splat, NumSclBits);
240246
}
241247

242-
static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
248+
static Constant *rebuildZeroUpperCst(const Constant *C, unsigned NumBits,
249+
unsigned /*NumElts*/,
243250
unsigned ScalarBitWidth) {
244-
Type *Ty = C->getType();
245-
Type *SclTy = Ty->getScalarType();
246-
unsigned NumBits = Ty->getPrimitiveSizeInBits();
251+
Type *SclTy = C->getType()->getScalarType();
247252
unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
248253
LLVMContext &Ctx = C->getContext();
249254

250255
if (NumBits > ScalarBitWidth) {
251256
// Determine if the upper bits are all zero.
252-
if (std::optional<APInt> Bits = extractConstantBits(C)) {
257+
if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
253258
if (Bits->countLeadingZeros() >= (NumBits - ScalarBitWidth)) {
254259
// If the original constant was made of smaller elements, try to retain
255260
// those types.
@@ -266,16 +271,15 @@ static Constant *rebuildZeroUpperCst(const Constant *C, unsigned /*NumElts*/,
266271
return nullptr;
267272
}
268273

269-
static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
274+
static Constant *rebuildExtCst(const Constant *C, bool IsSExt,
275+
unsigned NumBits, unsigned NumElts,
270276
unsigned SrcEltBitWidth) {
271-
Type *Ty = C->getType();
272-
unsigned NumBits = Ty->getPrimitiveSizeInBits();
273277
unsigned DstEltBitWidth = NumBits / NumElts;
274278
assert((NumBits % NumElts) == 0 && (NumBits % SrcEltBitWidth) == 0 &&
275279
(DstEltBitWidth % SrcEltBitWidth) == 0 &&
276280
(DstEltBitWidth > SrcEltBitWidth) && "Illegal extension width");
277281

278-
if (std::optional<APInt> Bits = extractConstantBits(C)) {
282+
if (std::optional<APInt> Bits = extractConstantBits(C, NumBits)) {
279283
assert((Bits->getBitWidth() / DstEltBitWidth) == NumElts &&
280284
(Bits->getBitWidth() % DstEltBitWidth) == 0 &&
281285
"Unexpected constant extension");
@@ -290,19 +294,20 @@ static Constant *rebuildExtCst(const Constant *C, bool IsSExt, unsigned NumElts,
290294
TruncBits.insertBits(Elt.trunc(SrcEltBitWidth), I * SrcEltBitWidth);
291295
}
292296

297+
Type *Ty = C->getType();
293298
return rebuildConstant(Ty->getContext(), Ty->getScalarType(), TruncBits,
294299
SrcEltBitWidth);
295300
}
296301

297302
return nullptr;
298303
}
299-
static Constant *rebuildSExtCst(const Constant *C, unsigned NumElts,
300-
unsigned SrcEltBitWidth) {
301-
return rebuildExtCst(C, true, NumElts, SrcEltBitWidth);
304+
static Constant *rebuildSExtCst(const Constant *C, unsigned NumBits,
305+
unsigned NumElts, unsigned SrcEltBitWidth) {
306+
return rebuildExtCst(C, true, NumBits, NumElts, SrcEltBitWidth);
302307
}
303-
static Constant *rebuildZExtCst(const Constant *C, unsigned NumElts,
304-
unsigned SrcEltBitWidth) {
305-
return rebuildExtCst(C, false, NumElts, SrcEltBitWidth);
308+
static Constant *rebuildZExtCst(const Constant *C, unsigned NumBits,
309+
unsigned NumElts, unsigned SrcEltBitWidth) {
310+
return rebuildExtCst(C, false, NumBits, NumElts, SrcEltBitWidth);
306311
}
307312

308313
bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
@@ -320,7 +325,7 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
320325
int Op;
321326
int NumCstElts;
322327
int BitWidth;
323-
std::function<Constant *(const Constant *, unsigned, unsigned)>
328+
std::function<Constant *(const Constant *, unsigned, unsigned, unsigned)>
324329
RebuildConstant;
325330
};
326331
auto FixupConstant = [&](ArrayRef<FixupEntry> Fixups, unsigned OperandNo) {
@@ -335,12 +340,13 @@ bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
335340
assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
336341
"Unexpected number of operands!");
337342
if (auto *C = X86::getConstantFromPool(MI, OperandNo)) {
343+
unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
338344
for (const FixupEntry &Fixup : Fixups) {
339345
if (Fixup.Op) {
340346
// Construct a suitable constant and adjust the MI to use the new
341347
// constant pool entry.
342-
if (Constant *NewCst =
343-
Fixup.RebuildConstant(C, Fixup.NumCstElts, Fixup.BitWidth)) {
348+
if (Constant *NewCst = Fixup.RebuildConstant(
349+
C, NumBits, Fixup.NumCstElts, Fixup.BitWidth)) {
344350
unsigned NewCPI =
345351
CP->getConstantPoolIndex(NewCst, Align(Fixup.BitWidth / 8));
346352
MI.setDesc(TII->get(Fixup.Op));

0 commit comments

Comments
 (0)