Skip to content

Commit c39915f

Browse files
authored
[mlir][NFC] Simplify constant checks with isOneInteger and renamed isZeroInteger. (#139340)
The revision adds isOneInteger helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner. For downstream users, you can update the code with the below script. ```bash sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` --------- Signed-off-by: hanhanW <[email protected]>
1 parent de3e8ff commit c39915f

22 files changed

+54
-76
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@
2424

2525
namespace mlir {
2626

27-
/// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp
28-
/// with attribute with value `0`.
29-
bool isZeroIndex(OpFoldResult v);
27+
/// Return true if `v` is an IntegerAttr with value `0`.
28+
bool isZeroInteger(OpFoldResult v);
29+
30+
/// Return true if `v` is an IntegerAttr with value `1`.
31+
bool isOneInteger(OpFoldResult v);
3032

3133
/// Represents a range (offset, size, and stride) where each element of the
3234
/// triple may be dynamic or static.

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
897897
OpFoldResult offset =
898898
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
899899
.front();
900-
if (isConstantIntValue(offset, 0)) {
900+
if (isZeroInteger(offset)) {
901901
rewriter.replaceOp(op, src);
902902
return success();
903903
}

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4488,8 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
44884488

44894489
// Return true if we have a zero-value tile.
44904490
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4491-
return llvm::any_of(
4492-
tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
4491+
return llvm::any_of(tiles, isZeroInteger);
44934492
};
44944493

44954494
// Verify tiles. Do not allow zero tiles.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3401,10 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
34013401
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
34023402
SmallVector<OpFoldResult> steps = loop.getMixedStep();
34033403

3404-
if (llvm::all_of(
3405-
lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
3406-
llvm::all_of(
3407-
steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
3404+
if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) {
34083405
return loop;
34093406
}
34103407

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,8 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
441441
// If the `padOp` has a nofold attribute and all paddings are known to be 0,
442442
// explicitly insert a `linalg.copy`.
443443
if (padOp.getNofoldAttr() &&
444-
llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
445-
llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
444+
llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) &&
445+
llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) {
446446
using bufferization::AllocTensorOp;
447447
Value allocated =
448448
rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
2424
#include "mlir/Dialect/Tensor/IR/Tensor.h"
2525
#include "mlir/Dialect/Utils/IndexingUtils.h"
26+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2627
#include "mlir/IR/AffineExpr.h"
2728
#include "mlir/IR/AffineMap.h"
2829
#include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
376377

377378
SmallVector<Value> threadIds = forallOp.getInductionVars();
378379
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
379-
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
380+
numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
380381
int64_t nLoops = loopRanges.size();
381382
tiledOffsets.reserve(nLoops);
382383
tiledSizes.reserve(nLoops);
383384
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
384385
bool overflow = loopIdx >= numThreads.size();
385-
bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
386+
bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);
386387
// Degenerate case: take the whole domain.
387388
if (overflow || isZero) {
388389
tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
413414
OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
414415
b, loc, i + j * m - n,
415416
{offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
416-
if (!isConstantIntValue(residualTileSize, 0)) {
417+
if (!isZeroInteger(residualTileSize)) {
417418
OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
418419
b, loc, -i + m, {offsetPerThread, size});
419420
tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
655656
Operation *tiledOp = nullptr;
656657

657658
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
658-
numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
659+
numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });
659660
SmallVector<Value> materializedNonZeroNumThreads =
660661
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
661662

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ struct LinalgOpPartialReductionInterface
369369

370370
SmallVector<OpFoldResult> tiledShape;
371371
for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
372-
if (isZeroIndex(tileSize)) {
372+
if (isZeroInteger(tileSize)) {
373373
tiledShape.push_back(dimSize);
374374
} else {
375375
tiledShape.push_back(tileSize);
@@ -732,7 +732,7 @@ struct PackOpTiling
732732
// iterated or inner dims are not tiled. Otherwise, it will generate a
733733
// sequence of non-trivial ops (for partial tiles).
734734
for (auto offset : offsets.take_back(numTiles))
735-
if (!isConstantIntValue(offset, 0))
735+
if (!isZeroInteger(offset))
736736
return failure();
737737

738738
for (auto iter :

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
5959
TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
6060

6161
void visitDimExpr(AffineDimExpr expr) {
62-
isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
62+
isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]);
6363
}
6464
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
6565
visit(expr.getLHS());
@@ -741,7 +741,7 @@ SmallVector<OpFoldResult> computeTileOffsets(OpBuilder &b, Location loc,
741741
SmallVector<OpFoldResult> offsets;
742742
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
743743
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
744-
bool isTiled = !isZeroIndex(tileSizes[idx]);
744+
bool isTiled = !isZeroInteger(tileSizes[idx]);
745745
offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0));
746746
LLVM_DEBUG(llvm::dbgs()
747747
<< "computeTileOffsets: " << offsets.back() << "\n");
@@ -754,7 +754,7 @@ SmallVector<OpFoldResult> computeTileSizes(OpBuilder &b, Location loc,
754754
ArrayRef<OpFoldResult> sizeBounds) {
755755
SmallVector<OpFoldResult> sizes;
756756
for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) {
757-
bool isTiled = !isZeroIndex(tileSizes[idx]);
757+
bool isTiled = !isZeroInteger(tileSizes[idx]);
758758
// Before composing, we need to make range a closed interval.
759759
OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx];
760760
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
@@ -810,7 +810,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
810810
bool omitPartialTileCheck) {
811811
assert(ivs.size() == static_cast<size_t>(llvm::count_if(
812812
llvm::make_range(tileSizes.begin(), tileSizes.end()),
813-
[](OpFoldResult v) { return !isZeroIndex(v); })) &&
813+
[](OpFoldResult v) { return !isZeroInteger(v); })) &&
814814
"expected as many ivs as non-zero sizes");
815815

816816
// Construct (potentially temporary) mins and maxes on which to apply maps

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1894,9 +1894,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
18941894
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
18951895
// are 0.
18961896
if (auto prev = src.getDefiningOp<SubViewOp>())
1897-
if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
1898-
return isConstantIntValue(val, 0);
1899-
}))
1897+
if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
19001898
return prev.getSource();
19011899

19021900
return nullptr;
@@ -3290,11 +3288,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
32903288
auto srcSizes = srcSubview.getMixedSizes();
32913289
auto sizes = getMixedSizes();
32923290
auto offsets = getMixedOffsets();
3293-
bool allOffsetsZero = llvm::all_of(
3294-
offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
3291+
bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
32953292
auto strides = getMixedStrides();
3296-
bool allStridesOne = llvm::all_of(
3297-
strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
3293+
bool allStridesOne = llvm::all_of(strides, isOneInteger);
32983294
bool allSizesSame = llvm::equal(sizes, srcSizes);
32993295
if (allOffsetsZero && allStridesOne && allSizesSame &&
33003296
resultMemrefType == sourceMemrefType)

mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,9 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
251251
// to do.
252252
SmallVector<OpFoldResult> indices =
253253
getAsOpFoldResult(loadStoreLikeOp.getIndices());
254-
if (llvm::all_of(indices, [](const OpFoldResult &opFold) {
255-
return isConstantIntValue(opFold, 0);
256-
})) {
254+
if (llvm::all_of(indices, isZeroInteger)) {
257255
return rewriter.notifyMatchFailure(
258256
loadStoreLikeOp, "no computation to extract: offsets are 0s");
259257
}

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
133133
tileSizes.resize(numLoops, zero);
134134
for (auto [index, range, nt] :
135135
llvm::enumerate(iterationDomain, numThreads)) {
136-
if (isConstantIntValue(nt, 0))
136+
if (isZeroInteger(nt))
137137
continue;
138138

139139
tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
265265

266266
// Non-tiled cases, set the offset and size to the
267267
// `loopRange.offset/size`.
268-
if (isConstantIntValue(nt, 0)) {
268+
if (isZeroInteger(nt)) {
269269
offsets.push_back(loopRange.offset);
270270
sizes.push_back(loopRange.size);
271271
continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
280280
{loopRange.offset, nt, tileSize, loopRange.size});
281281

282282
OpFoldResult size = tileSize;
283-
if (!isConstantIntValue(residualTileSize, 0)) {
283+
if (!isZeroInteger(residualTileSize)) {
284284
OpFoldResult sizeMinusOffsetPerThread =
285285
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
286286
{offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
316316

317317
// Non-tiled cases, set the offset and size to the
318318
// `loopRange.offset/size`.
319-
if (isConstantIntValue(tileSize, 0)) {
319+
if (isZeroInteger(tileSize)) {
320320
offsets.push_back(loopRange.offset);
321321
sizes.push_back(loopRange.size);
322322
continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
341341
SmallVector<OpFoldResult> lbs, ubs, steps;
342342
for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
343343
// No loop if the tile size is 0.
344-
if (isConstantIntValue(tileSize, 0))
344+
if (isZeroInteger(tileSize))
345345
continue;
346346
lbs.push_back(loopRange.offset);
347347
ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
495495
// Prune the zero numthreads.
496496
SmallVector<OpFoldResult> nonZeroNumThreads;
497497
for (auto nt : numThreads) {
498-
if (isConstantIntValue(nt, 0))
498+
if (isZeroInteger(nt))
499499
continue;
500500
nonZeroNumThreads.push_back(nt);
501501
}
@@ -551,7 +551,7 @@ static LogicalResult generateLoopNest(
551551
YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
552552
// If the tile sizes are all zero, no loops are generated. Just call the
553553
// callback function to handle untiled case.
554-
if (llvm::all_of(tileSizes, isZeroIndex)) {
554+
if (llvm::all_of(tileSizes, isZeroInteger)) {
555555
SmallVector<Value> tiledResults;
556556
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
557557
return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
@@ -999,7 +999,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
999999
// 5b. Early return cloned op if tiling is not happening. We can not
10001000
// return the original op because it could lead to `rewriter.replaceOp(op,
10011001
// op->getResults())` and users would get crash.
1002-
if (llvm::all_of(tileSizes, isZeroIndex)) {
1002+
if (llvm::all_of(tileSizes, isZeroInteger)) {
10031003
tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
10041004
tilingResult =
10051005
TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
12901290
sliceSizes = sliceOp.getMixedSizes();
12911291

12921292
// expect all strides of sliceOp being 1
1293-
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
1294-
return !isConstantIntValue(ofr, 1);
1295-
}))
1293+
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
12961294
return failure();
12971295

12981296
unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
21142112
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
21152113

21162114
// 9. Check all insert stride is 1.
2117-
if (llvm::any_of(strides, [](OpFoldResult stride) {
2118-
return !isConstantIntValue(stride, 1);
2119-
})) {
2115+
if (!llvm::all_of(strides, isOneInteger)) {
21202116
return rewriter.notifyMatchFailure(
21212117
candidateSliceOp, "containingOp's result yield with stride");
21222118
}

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
768768
// If an `affine.apply` operation is generated for denormalization, the use
769769
// of `origLb` in those ops must not be replaced. These arent not generated
770770
// when `origLb == 0` and `origStep == 1`.
771-
if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
771+
if (!isZeroInteger(origLb) || !isOneInteger(origStep)) {
772772
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
773773
preservedUses.insert(preservedUse);
774774
}
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
785785
}
786786
Value denormalizedIv;
787787
SmallPtrSet<Operation *, 2> preserve;
788-
bool isStepOne = isConstantIntValue(origStep, 1);
789-
bool isZeroBased = isConstantIntValue(origLb, 0);
788+
bool isStepOne = isOneInteger(origStep);
789+
bool isZeroBased = isZeroInteger(origLb);
790790

791791
Value scaled = normalizedIv;
792792
if (!isStepOne) {

mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
614614
// Check for single block, unit-stride for-loop that is generated by
615615
// sparsifier, which means no data dependence analysis is required,
616616
// and its loop-body is very restricted in form.
617-
if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
617+
if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) ||
618618
!op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
619619
return failure();
620620
// Analyze (!codegen) and rewrite (codegen) loop-body.

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2839,8 +2839,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
28392839
return getResult();
28402840
if (auto result = foldInsertAfterExtractSlice(*this))
28412841
return result;
2842-
if (llvm::any_of(getMixedSizes(),
2843-
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2842+
if (llvm::any_of(getMixedSizes(), isZeroInteger))
28442843
return getDest();
28452844
return OpFoldResult();
28462845
}

mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
135135
SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
136136
for (unsigned dim = 0; dim < rank; ++dim) {
137137
auto low = padOp.getMixedLowPad()[dim];
138-
bool hasLowPad = !isConstantIntValue(low, 0);
138+
bool hasLowPad = !isZeroInteger(low);
139139
auto high = padOp.getMixedHighPad()[dim];
140-
bool hasHighPad = !isConstantIntValue(high, 0);
140+
bool hasHighPad = !isZeroInteger(high);
141141
auto offset = offsets[dim];
142142
auto length = sizes[dim];
143143
// If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
208208

209209
// Check if newLength is zero. In that case, no SubTensorOp should be
210210
// executed.
211-
if (isConstantIntValue(newLength, 0)) {
211+
if (isZeroInteger(newLength)) {
212212
hasZeroLen = true;
213213
} else if (!hasZeroLen) {
214214
Value check = b.create<arith::CmpIOp>(

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,7 @@ static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp,
646646
// Dest is not read if it is entirely overwritten. E.g.:
647647
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
648648
bool allOffsetsZero =
649-
llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex);
649+
llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger);
650650
RankedTensorType destType = insertSliceOp.getDestType();
651651
bool sizesMatchDestSizes =
652652
areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape());

mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
452452
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
453453
isZeroOffsetAndFullSize =
454454
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
455-
if (!isConstantIntValue(offset, 0))
455+
if (!isZeroInteger(offset))
456456
return false;
457457
FailureOr<bool> maybeEqual =
458458
ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
476476
// Find the first expanded dim after the first dim with non-unit extracted
477477
// size.
478478
for (; i < e; ++i) {
479-
if (!isConstantIntValue(sizes[indices[i]], 1)) {
479+
if (!isOneInteger(sizes[indices[i]])) {
480480
// +1 to skip the first non-unit size dim.
481481
i++;
482482
break;

mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
2727
return failure();
2828

2929
// `TilingInterface` currently only supports strides being 1.
30-
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
31-
return !isConstantIntValue(ofr, 1);
32-
}))
30+
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
3331
return failure();
3432

3533
FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -49,9 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
4947
return failure();
5048

5149
// `TilingInterface` currently only supports strides being 1.
52-
if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
53-
return !isConstantIntValue(ofr, 1);
54-
}))
50+
if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger))
5551
return failure();
5652

5753
FailureOr<TilingResult> tiledResult =

0 commit comments

Comments
 (0)