Skip to content

Commit 0722800

Browse files
committed
[RISCV] Match constant indices of non-index type when forming strided ops (#65777)
When checking to see if our index expressions can be converted into strided operations, we previously gave up if the index type wasn't an exact match for the intptrty for the address. Per gep semantics, this mismatch implies a sext or trunc cast to the respective index type. For constants, go ahead and evaluate that cast instead of giving up. Note that the motivation of this is mostly test cleanup. We canonicalize at IR such that the gep index will match the intptrty. This is mostly useful so that we can write both RV32 and RV64 tests from the same source. Its also helpful in preventing confusion - I've stumbled across this at least four times now and wasted time each one. Note: The test change for scatters unit stride cases contains a minor regression for rv32 and 64 bit indices. This is an artifact of order in which changes are landing. This will be addressed in a near future change for all configurations.
1 parent fa44ec7 commit 0722800

File tree

3 files changed

+71
-440
lines changed

3 files changed

+71
-440
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,19 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
376376
// We can't extract the stride if the arithmetic is done at a different size
377377
// than the pointer type. Adding the stride later may not wrap correctly.
378378
// Technically we could handle wider indices, but I don't expect that in
379-
// practice.
379+
// practice. Handle one special case here - constants. This simplifies
380+
// writing test cases.
380381
Value *VecIndex = Ops[*VecOperand];
381382
Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
382-
if (VecIndex->getType() != VecIntPtrTy)
383-
return std::make_pair(nullptr, nullptr);
383+
if (VecIndex->getType() != VecIntPtrTy) {
384+
auto *VecIndexC = dyn_cast<Constant>(VecIndex);
385+
if (!VecIndexC)
386+
return std::make_pair(nullptr, nullptr);
387+
if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits())
388+
VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC, VecIntPtrTy);
389+
else
390+
VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC, VecIntPtrTy);
391+
}
384392

385393
// Handle the non-recursive case. This is what we see if the vectorizer
386394
// decides to use a scalar IV + vid on demand instead of a vector IV.

0 commit comments

Comments
 (0)