@@ -112,6 +112,64 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
112
112
Operation *xfer,
113
113
RewriterBase &rewriter);
114
114
115
+ // / A pattern for ops that implement `MaskableOpInterface` and that _might_ be
116
+ // / masked (i.e. inside `vector.mask` Op region). In particular:
117
+ // / 1. Matches `SourceOp` operation, Op.
118
+ // / 2.1. If Op is masked, retrieves the masking Op, maskOp, and updates the
119
+ // / insertion point to avoid inserting new ops into the `vector.mask` Op
120
+ // / region (which only allows one Op).
121
+ // / 2.2 If Op is not masked, this step is skipped.
122
+ // / 3. Invokes `matchAndRewriteMaskableOp` on Op and optionally maskOp if
123
+ // / found in step 2.1.
124
+ // /
125
+ // / This wrapper frees patterns from re-implementing the logic to update the
126
+ // / insertion point when a maskable Op is masked. Such patterns are still
127
+ // / responsible for providing an updated ("rewritten") version of:
128
+ // / a. the source Op when mask _is not_ present,
129
+ // / b. the source Op and the masking Op when mask _is_ present.
130
+ // / Note that the return value from `matchAndRewriteMaskableOp` depends on the
131
+ // / case above.
132
+ template <class SourceOp >
133
+ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
134
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
135
+
136
+ private:
137
+ LogicalResult matchAndRewrite (SourceOp sourceOp,
138
+ PatternRewriter &rewriter) const final {
139
+ auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation ());
140
+ if (!maskableOp)
141
+ return failure ();
142
+
143
+ Operation *rootOp = sourceOp;
144
+
145
+ // If this Op is masked, update the insertion point to avoid inserting into
146
+ // the vector.mask Op region.
147
+ OpBuilder::InsertionGuard guard (rewriter);
148
+ MaskingOpInterface maskOp;
149
+ if (maskableOp.isMasked ()) {
150
+ maskOp = maskableOp.getMaskingOp ();
151
+ rewriter.setInsertionPoint (maskOp);
152
+ rootOp = maskOp;
153
+ }
154
+
155
+ FailureOr<Value> newOp =
156
+ matchAndRewriteMaskableOp (sourceOp, maskOp, rewriter);
157
+ if (failed (newOp))
158
+ return failure ();
159
+
160
+ rewriter.replaceOp (rootOp, *newOp);
161
+ return success ();
162
+ }
163
+
164
+ public:
165
+ // Matches SourceOp that can potentially be masked with `maskingOp`. If the
166
+ // latter is present, returns an updated masking op (with a replacement for
167
+ // `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
168
+ virtual FailureOr<Value>
169
+ matchAndRewriteMaskableOp (SourceOp sourceOp, MaskingOpInterface maskingOp,
170
+ PatternRewriter &rewriter) const = 0 ;
171
+ };
172
+
115
173
} // namespace vector
116
174
117
175
// / Constructs a permutation map of invariant memref indices to vector
0 commit comments