Skip to content

Commit c4ba734

Browse files
[mlir] Compare std::optional<T> to values directly (NFC) (#144241)
This patch transforms: X && *X == Y to: X == Y where X is of std::optional<T>, and Y is of T or similar.
1 parent a0c00cc commit c4ba734

File tree

7 files changed

+11
-16
lines changed

7 files changed

+11
-16
lines changed

mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
173173

174174
// Ignore the specified operand, usually because this position was
175175
// visited in an upward traversal via an iterative choice.
176-
if (ignoreOperand && *ignoreOperand == operandIt.index())
176+
if (ignoreOperand == operandIt.index())
177177
continue;
178178

179179
Position *pos =

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,7 +2367,7 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
23672367
if (forOp.getNumResults() == 0)
23682368
return success();
23692369
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
2370-
if (tripCount && *tripCount == 0) {
2370+
if (tripCount == 0) {
23712371
// The initial values of the iteration arguments would be the op's
23722372
// results.
23732373
rewriter.replaceOp(forOp, forOp.getInits());
@@ -2447,7 +2447,7 @@ void AffineForOp::getSuccessorRegions(
24472447

24482448
// From the loop body, if the trip count is one, we can only branch back to
24492449
// the parent.
2450-
if (!point.isParent() && tripCount && *tripCount == 1) {
2450+
if (!point.isParent() && tripCount == 1) {
24512451
regions.push_back(RegionSuccessor(getResults()));
24522452
return;
24532453
}
@@ -2460,8 +2460,7 @@ void AffineForOp::getSuccessorRegions(
24602460

24612461
/// Returns true if the affine.for has zero iterations in trivial cases.
24622462
static bool hasTrivialZeroTripCount(AffineForOp op) {
2463-
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
2464-
return tripCount && *tripCount == 0;
2463+
return getTrivialConstantTripCount(op) == 0;
24652464
}
24662465

24672466
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
@@ -4789,7 +4788,7 @@ struct DropUnitExtentBasis
47894788
llvm::enumerate(delinearizeOp.getPaddedBasis())) {
47904789
std::optional<int64_t> basisVal =
47914790
basis ? getConstantIntValue(basis) : std::nullopt;
4792-
if (basisVal && *basisVal == 1)
4791+
if (basisVal == 1)
47934792
replacements[index] = getZero();
47944793
else
47954794
newBasis.push_back(basis);

mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,8 +1015,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
10151015

10161016
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
10171017
if (unrollFactor == 1) {
1018-
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
1019-
failed(promoteIfSingleIteration(forOp)))
1018+
if (mayBeConstantTripCount == 1 && failed(promoteIfSingleIteration(forOp)))
10201019
return failure();
10211020
return success();
10221021
}
@@ -1103,8 +1102,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
11031102

11041103
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
11051104
if (unrollJamFactor == 1) {
1106-
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
1107-
failed(promoteIfSingleIteration(forOp)))
1105+
if (mayBeConstantTripCount == 1 && failed(promoteIfSingleIteration(forOp)))
11081106
return failure();
11091107
return success();
11101108
}

mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -606,8 +606,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
606606
int64_t padRank = sourceShape.size();
607607

608608
auto isStaticZero = [](OpFoldResult f) {
609-
std::optional<int64_t> maybeInt = getConstantIntValue(f);
610-
return maybeInt && *maybeInt == 0;
609+
return getConstantIntValue(f) == 0;
611610
};
612611

613612
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
688688
// tensors with "0" dimensions would never be constructed.
689689
int64_t shapeSize = shape[r];
690690
std::optional<int64_t> sizeCst = getConstantIntValue(size);
691-
auto hasTileSizeOne = sizeCst && *sizeCst == 1;
691+
auto hasTileSizeOne = sizeCst == 1;
692692
auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
693693
((shapeSize % *sizeCst) == 0);
694694
if (!hasTileSizeOne && !dividesEvenly) {

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -737,7 +737,7 @@ static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
737737
spirv::SPIRVDialect::getAttributeName(
738738
spirv::Decoration::BuiltIn))) {
739739
auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
740-
if (varBuiltIn && *varBuiltIn == builtin) {
740+
if (varBuiltIn == builtin) {
741741
return varOp;
742742
}
743743
}

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,7 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
142142
}
143143

144144
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
145-
auto val = getConstantIntValue(ofr);
146-
return val && *val == value;
145+
return getConstantIntValue(ofr) == value;
147146
}
148147

149148
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {

0 commit comments

Comments
 (0)