Skip to content

Commit 0ac889b

Browse files
authored
[mlir][Affine] Extend linearize/delinearize cancelation to partial tails (#116872)
xisting patterns would cancel out the linearize_index / delinearize_index pairs that had the exact same basis, like %0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index %1:4 = affine.delinearize_index %0 into (W, X, Y, Z) : index, ... This commit extends the canonicalization to handle instances where the entire basis doesn't match, as in %0 = affine.linearize_index [%w, %x, %y, %z] by (X, Y, Z) : index %1:3 = affine.delinearize_index %0 into (XY, Y, Z) : index, ... where we can replace the last two results of the delinearize_index operation with the last two inputs of the linearize_index, creating a more canonical (fewer total computations to perform) result.
1 parent 6f68d03 commit 0ac889b

File tree

2 files changed

+63
-11
lines changed

2 files changed

+63
-11
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4666,14 +4666,16 @@ struct DropUnitExtentBasis
46664666
};
46674667

46684668
/// If a `affine.delinearize_index`'s input is a `affine.linearize_index
4669-
/// disjoint` and the two operations have the same basis, replace the
4670-
/// delinearizeation results with the inputs of the `affine.linearize_index`
4671-
/// since they are exact inverses of each other.
4669+
/// disjoint` and the two operations end with the same basis elements,
4670+
/// cancel those parts of the operations out because they are inverses
4671+
/// of each other.
4672+
///
4673+
/// If the operations have the same basis, cancel them entirely.
46724674
///
46734675
/// The `disjoint` flag is needed on the `affine.linearize_index` because
46744676
/// otherwise, there is no guarantee that the inputs to the linearization are
46754677
/// in-bounds the way the outputs of the delinearization would be.
4676-
struct CancelDelinearizeOfLinearizeDisjointExact
4678+
struct CancelDelinearizeOfLinearizeDisjointExactTail
46774679
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
46784680
using OpRewritePattern::OpRewritePattern;
46794681

@@ -4685,22 +4687,54 @@ struct CancelDelinearizeOfLinearizeDisjointExact
46854687
return rewriter.notifyMatchFailure(delinearizeOp,
46864688
"index doesn't come from linearize");
46874689

4688-
if (!linearizeOp.getDisjoint() ||
4689-
linearizeOp.getEffectiveBasis() != delinearizeOp.getEffectiveBasis())
4690+
if (!linearizeOp.getDisjoint())
4691+
return rewriter.notifyMatchFailure(linearizeOp, "not disjoint");
4692+
4693+
ValueRange linearizeIns = linearizeOp.getMultiIndex();
4694+
// Note: we use the full basis so we don't lose outer bounds later.
4695+
SmallVector<OpFoldResult> linearizeBasis = linearizeOp.getMixedBasis();
4696+
SmallVector<OpFoldResult> delinearizeBasis = delinearizeOp.getMixedBasis();
4697+
size_t numMatches = 0;
4698+
for (auto [linSize, delinSize] : llvm::zip(
4699+
llvm::reverse(linearizeBasis), llvm::reverse(delinearizeBasis))) {
4700+
if (linSize != delinSize)
4701+
break;
4702+
++numMatches;
4703+
}
4704+
4705+
if (numMatches == 0)
46904706
return rewriter.notifyMatchFailure(
4691-
linearizeOp, "not disjoint or basis doesn't match delinearize");
4707+
delinearizeOp, "final basis element doesn't match linearize");
4708+
4709+
// The easy case: everything lines up and the basis match sup completely.
4710+
if (numMatches == linearizeBasis.size() &&
4711+
numMatches == delinearizeBasis.size() &&
4712+
linearizeIns.size() == delinearizeOp.getNumResults()) {
4713+
rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4714+
return success();
4715+
}
46924716

4693-
rewriter.replaceOp(delinearizeOp, linearizeOp.getMultiIndex());
4717+
Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4718+
linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
4719+
ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
4720+
linearizeOp.getDisjoint());
4721+
auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4722+
delinearizeOp.getLoc(), newLinearize,
4723+
ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
4724+
delinearizeOp.hasOuterBound());
4725+
SmallVector<Value> mergedResults(newDelinearize.getResults());
4726+
mergedResults.append(linearizeIns.take_back(numMatches).begin(),
4727+
linearizeIns.take_back(numMatches).end());
4728+
rewriter.replaceOp(delinearizeOp, mergedResults);
46944729
return success();
46954730
}
46964731
};
46974732
} // namespace
46984733

46994734
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
47004735
RewritePatternSet &patterns, MLIRContext *context) {
4701-
patterns
4702-
.insert<CancelDelinearizeOfLinearizeDisjointExact, DropUnitExtentBasis>(
4703-
context);
4736+
patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
4737+
DropUnitExtentBasis>(context);
47044738
}
47054739

47064740
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1739,6 +1739,24 @@ func.func @cancel_delinearize_linearize_disjoint_delinearize_extra_bound(%arg0:
17391739

17401740
// -----
17411741

1742+
// CHECK-LABEL: func @cancel_delinearize_linearize_disjoint_partial(
1743+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,
1744+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index,
1745+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index,
1746+
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index,
1747+
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index)
1748+
// CHECK: %[[LIN:.+]] = affine.linearize_index disjoint [%[[ARG0]], %[[ARG1]]] by (%[[ARG3]], 4) : index
1749+
// CHECK: %[[DELIN:.+]]:2 = affine.delinearize_index %[[LIN]] into (8) : index, index
1750+
// CHECK: return %[[DELIN]]#0, %[[DELIN]]#1, %[[ARG2]]
1751+
func.func @cancel_delinearize_linearize_disjoint_partial(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> (index, index, index) {
1752+
%0 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%arg3, 4, %arg4) : index
1753+
%1:3 = affine.delinearize_index %0 into (8, %arg4)
1754+
: index, index, index
1755+
return %1#0, %1#1, %1#2 : index, index, index
1756+
}
1757+
1758+
// -----
1759+
17421760
// Without `disjoint`, the cancelation isn't guaranteed to be the identity.
17431761
// CHECK-LABEL: func @no_cancel_delinearize_linearize_exact(
17441762
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index,

0 commit comments

Comments
 (0)