Skip to content

Commit 67f1d62

Browse files
committed
[RISCV] Match indices based on significant bits when forming strided ops
When checking to see if our index expressions can be converted into strided operations, we previously gave up if the index type wasn't an exact match for the intptrty for the address. Instead, we can ask known bits how many bits are needed, and proceed as long as the number of required bits fits in intptrty. Note that the motivation of this is mostly test cleanup. We canonicalize at IR such that the gep index will match the intptrty. This is mostly useful so that we can write both RV32 and RV64 tests from the same source. Its also helpful in preventing confusion - I've stumbled across this at least four times now and wasted time each one. Interestingly, this does appear to catch some loop cases that O3 does not canonicalize. I don't think the vectorizer is likely to emit indices narrower that intptr, but if another frontend did, this might pick them up. This may also pick up some pass ordering problems, but that's a happy accident at best.
1 parent 0a29827 commit 67f1d62

File tree

4 files changed

+87
-339
lines changed

4 files changed

+87
-339
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,15 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
373373
if (!VecOperand)
374374
return std::make_pair(nullptr, nullptr);
375375

376-
// We can't extract the stride if the arithmetic is done at a different size
377-
// than the pointer type. Adding the stride later may not wrap correctly.
378-
// Technically we could handle wider indices, but I don't expect that in
379-
// practice.
376+
// We need the number of significant bits to match the index type. IF it
377+
// doesn't, then adding the stride later may not wrap correctly.
380378
Value *VecIndex = Ops[*VecOperand];
381379
Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
382-
if (VecIndex->getType() != VecIntPtrTy)
383-
return std::make_pair(nullptr, nullptr);
380+
if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits()) {
381+
unsigned MaxBits = ComputeMaxSignificantBits(VecIndex, *DL);
382+
if (MaxBits > VecIntPtrTy->getScalarSizeInBits())
383+
return std::make_pair(nullptr, nullptr);
384+
}
384385

385386
// Handle the non-recursive case. This is what we see if the vectorizer
386387
// decides to use a scalar IV + vid on demand instead of a vector IV.
@@ -397,7 +398,8 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
397398

398399
// Convert stride to pointer size if needed.
399400
Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
400-
assert(Stride->getType() == IntPtrTy && "Unexpected type");
401+
assert(IntPtrTy == VecIntPtrTy->getScalarType());
402+
Stride = Builder.CreateSExtOrTrunc(Stride, IntPtrTy);
401403

402404
// Scale the stride by the size of the indexed type.
403405
if (TypeScale != 1)
@@ -437,7 +439,8 @@ RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
437439

438440
// Convert stride to pointer size if needed.
439441
Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
440-
assert(Stride->getType() == IntPtrTy && "Unexpected type");
442+
assert(IntPtrTy == VecIntPtrTy->getScalarType());
443+
Stride = Builder.CreateSExtOrTrunc(Stride, IntPtrTy);
441444

442445
// Scale the stride by the size of the indexed type.
443446
if (TypeScale != 1)

0 commit comments

Comments
 (0)