Skip to content

Commit b7324b6

Browse files
authored
[mlir][vector] Adds pattern rewrite for maskable Ops (#83827)
Adds a generic pattern rewrite for maskable Ops, `MaskableOpRewritePattern`, that will work for both masked and un-masked cases, e.g. for both: * `vector.mask {vector.contract}` (masked), and * `vector.contract` (not masked). This helps to reduce code-duplication and standardise how we implement such patterns. Fixes #78787
1 parent f6f474c commit b7324b6

File tree

2 files changed

+143
-94
lines changed

2 files changed

+143
-94
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,64 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
112112
Operation *xfer,
113113
RewriterBase &rewriter);
114114

115+
/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
116+
/// masked (i.e. inside `vector.mask` Op region). In particular:
117+
/// 1. Matches `SourceOp` operation, Op.
118+
/// 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
119+
/// insertion point to avoid inserting new ops into the `vector.mask` Op
120+
/// region (which only allows one Op).
121+
/// 2.2 If Op is not masked, this step is skipped.
122+
/// 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
123+
/// found in step 2.1.
124+
///
125+
/// This wrapper frees patterns from re-implementing the logic to update the
126+
/// insertion point when a maskable Op is masked. Such patterns are still
127+
/// responsible for providing an updated ("rewritten") version of:
128+
/// a. the source Op when mask _is not_ present,
129+
/// b. the source Op and the masking Op when mask _is_ present.
130+
/// Note that the return value from `matchAndRewriteMaskableOp` depends on the
131+
/// case above.
132+
template <class SourceOp>
133+
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
134+
using OpRewritePattern<SourceOp>::OpRewritePattern;
135+
136+
private:
137+
LogicalResult matchAndRewrite(SourceOp sourceOp,
138+
PatternRewriter &rewriter) const final {
139+
auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
140+
if (!maskableOp)
141+
return failure();
142+
143+
Operation *rootOp = sourceOp;
144+
145+
// If this Op is masked, update the insertion point to avoid inserting into
146+
// the vector.mask Op region.
147+
OpBuilder::InsertionGuard guard(rewriter);
148+
MaskingOpInterface maskOp;
149+
if (maskableOp.isMasked()) {
150+
maskOp = maskableOp.getMaskingOp();
151+
rewriter.setInsertionPoint(maskOp);
152+
rootOp = maskOp;
153+
}
154+
155+
FailureOr<Value> newOp =
156+
matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
157+
if (failed(newOp))
158+
return failure();
159+
160+
rewriter.replaceOp(rootOp, *newOp);
161+
return success();
162+
}
163+
164+
public:
165+
// Matches SourceOp that can potentially be masked with `maskingOp`. If the
166+
// latter is present, returns an updated masking op (with a replacement for
167+
// `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
168+
virtual FailureOr<Value>
169+
matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
170+
PatternRewriter &rewriter) const = 0;
171+
};
172+
115173
} // namespace vector
116174

117175
/// Constructs a permutation map of invariant memref indices to vector

0 commit comments

Comments
 (0)