-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[LoopIdiom] Improve code; use SCEVPatternMatch (NFC) #139540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Ramkumar Ramachandra (artagnon) ChangesFull diff: https://github.com/llvm/llvm-project/pull/139540.diff 1 Files Affected:
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 8f5d1ecba982d..5f59ec6daaba5 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -32,7 +32,6 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -49,6 +48,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -91,6 +91,7 @@
#include <vector>
using namespace llvm;
+using namespace SCEVPatternMatch;
#define DEBUG_TYPE "loop-idiom"
@@ -340,9 +341,8 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
// If this loop executes exactly one time, then it should be peeled, not
// optimized by this pass.
- if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
- if (BECst->getAPInt() == 0)
- return false;
+ if (match(BECount, m_scev_SpecificInt(0)))
+ return false;
SmallVector<BasicBlock *, 8> ExitBlocks;
CurLoop->getUniqueExitBlocks(ExitBlocks);
@@ -805,20 +805,17 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
// Check if the stride matches the size of the memcpy. If so, then we know
// that every byte is touched in the loop.
- const SCEVConstant *ConstStoreStride =
- dyn_cast<SCEVConstant>(StoreEv->getOperand(1));
- const SCEVConstant *ConstLoadStride =
- dyn_cast<SCEVConstant>(LoadEv->getOperand(1));
- if (!ConstStoreStride || !ConstLoadStride)
+ const APInt *StoreStrideValue, *LoadStrideValue;
+ if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
+ !match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
return false;
- APInt StoreStrideValue = ConstStoreStride->getAPInt();
- APInt LoadStrideValue = ConstLoadStride->getAPInt();
// Huge stride value - give up
- if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
+ if (StoreStrideValue->getBitWidth() > 64 ||
+ LoadStrideValue->getBitWidth() > 64)
return false;
- if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
+ if (SizeInBytes != *StoreStrideValue && SizeInBytes != -*StoreStrideValue) {
ORE.emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
<< ore::NV("Inst", "memcpy") << " in "
@@ -829,8 +826,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
return false;
}
- int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
- int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
+ int64_t StoreStrideInt = StoreStrideValue->getSExtValue();
+ int64_t LoadStrideInt = LoadStrideValue->getSExtValue();
// Check if the load stride matches the store stride.
if (StoreStrideInt != LoadStrideInt)
return false;
@@ -879,15 +876,14 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
// we know that every byte is touched in the loop.
LLVM_DEBUG(dbgs() << " memset size is constant\n");
uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
- const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
- if (!ConstStride)
+ const APInt *Stride;
+ if (!match(Ev->getOperand(1), m_scev_APInt(Stride)))
return false;
- APInt Stride = ConstStride->getAPInt();
- if (SizeInBytes != Stride && SizeInBytes != -Stride)
+ if (SizeInBytes != *Stride && SizeInBytes != -*Stride)
return false;
- IsNegStride = SizeInBytes == -Stride;
+ IsNegStride = SizeInBytes == -*Stride;
} else {
// Memset size is non-constant.
// Check if the pointer stride matches the memset size.
@@ -963,11 +959,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
// If the loop iterates a fixed number of times, we can refine the access size
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
- const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
- const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
- if (BECst && ConstSize) {
- std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
- std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
+ const APInt *BECst, *ConstSize;
+ if (match(BECount, m_scev_APInt(BECst)) &&
+ match(StoreSizeSCEV, m_scev_APInt(ConstSize))) {
+ std::optional<uint64_t> BEInt = BECst->tryZExtValue();
+ std::optional<uint64_t> SizeInt = ConstSize->tryZExtValue();
// FIXME: Should this check for overflow?
if (BEInt && SizeInt)
AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
@@ -1605,16 +1601,11 @@ class StrlenVerifier {
LLVM_DEBUG(dbgs() << "pointer load scev: " << *LoadEv << "\n");
- const SCEVConstant *Step =
- dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
- if (!Step)
+ const APInt *Step;
+ if (!match(LoadEv->getStepRecurrence(*SE), m_scev_APInt(Step)))
return false;
- unsigned StepSize = 0;
- StepSizeCI = dyn_cast<ConstantInt>(Step->getValue());
- if (!StepSizeCI)
- return false;
- StepSize = StepSizeCI->getZExtValue();
+ unsigned StepSize = Step->getZExtValue();
// Verify that StepSize is consistent with platform char width.
OpWidth = OperandType->getIntegerBitWidth();
@@ -3294,9 +3285,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
// Ok, transform appears worthwhile.
MadeChange = true;
- bool OffsetIsZero = false;
- if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(ExtraOffsetExpr))
- OffsetIsZero = ExtraOffsetExprC->isZero();
+ bool OffsetIsZero = match(ExtraOffsetExpr, m_scev_SpecificInt(0));
// Step 1: Compute the loop's final IV value / trip count.
|
Gentle ping. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
||
const SCEV *PointerStrideSCEV = Ev->getOperand(1); | ||
const SCEV *MemsetSizeSCEV = SE->getSCEV(MSI->getLength()); | ||
if (!PointerStrideSCEV || !MemsetSizeSCEV) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Side note: This check is pointless.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for noticing! Will strip in a follow-up.
No description provided.