Skip to content

Commit 3ae7afc

Browse files
committed
fixup! [mlir][vector] Adds pattern rewrite for maskable Ops
Move pattern
1 parent 69814ed commit 3ae7afc

File tree

3 files changed

+90
-59
lines changed

3 files changed

+90
-59
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,64 @@ SmallVector<OpFoldResult> getMixedSizesXfer(bool hasTensorSemantics,
112112
Operation *xfer,
113113
RewriterBase &rewriter);
114114

115+
/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
116+
/// masked (i.e. inside `vector.mask` Op region). In particular:
117+
/// 1. It matches `SourceOp` operation, Op.
118+
/// 2. If Op is masked, retrieves the mask and updates the insertion point to
119+
/// avoid inserting new ops into `vector.mask` Op region (which only allows
120+
/// one Op). If the Op is not masked, this step is a nop.
121+
/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
122+
/// required) in the matched `vector.mask` operation from step 2.
123+
///
124+
/// It frees the patterns implementing this class from worrying about the
125+
/// logic to update the insertion point. However, those patterns are still
126+
/// responsible for providing an updated version of:
127+
/// * the source Op when mask _is not_ present,
128+
/// * the source Op *and* the mask Op when mask _is_ present.
129+
template <class SourceOp>
130+
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
131+
using OpRewritePattern<SourceOp>::OpRewritePattern;
132+
133+
private:
134+
LogicalResult matchAndRewrite(SourceOp sourceOp,
135+
PatternRewriter &rewriter) const final {
136+
auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
137+
if (!maskableOp)
138+
return failure();
139+
140+
// Op to update
141+
Operation *rootOp = sourceOp;
142+
143+
// If this Op is masked:
144+
// * update the insertion point to avoid inserting into the vector.mask
145+
// Op region,
146+
// * update the Op to rewrite so that it's the parent vector.mask Op
147+
OpBuilder::InsertionGuard guard(rewriter);
148+
MaskingOpInterface maskOp;
149+
if (maskableOp.isMasked()) {
150+
maskOp = maskableOp.getMaskingOp();
151+
rewriter.setInsertionPoint(maskOp);
152+
rootOp = maskOp;
153+
}
154+
155+
FailureOr<Value> newOp =
156+
matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
157+
if (failed(newOp))
158+
return failure();
159+
160+
rewriter.replaceOp(rootOp, *newOp);
161+
return success();
162+
}
163+
164+
public:
165+
// Matches SourceOp that can potentially be masked with `maskingOp`. If the
166+
// latter is present, returns an updated masking op (with a replacement for
167+
// `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
168+
virtual FailureOr<Value>
169+
matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
170+
PatternRewriter &rewriter) const = 0;
171+
};
172+
115173
} // namespace vector
116174

117175
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -211,64 +211,6 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
211211

212212
namespace {
213213

214-
/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
215-
/// masked (i.e. inside `vector.mask` Op region). In particular:
216-
/// 1. It matches `SourceOp` operation, Op.
217-
/// 2. If Op is masked, retrieves the mask and updates the insertion point to
218-
/// avoid inserting new ops into `vector.mask` Op region (which only allows
219-
/// one Op). If the Op is not masked, this step is a nop.
220-
/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
221-
/// required) in the matched `vector.mask` operation from step 2.
222-
///
223-
/// It frees the patterns implementing this class from worrying about the
224-
/// logic to update the insertion point. However, those patterns are still
225-
/// responsible for providing an updated version of:
226-
/// * the source Op when mask _is not_ present,
227-
/// * the source Op *and* the mask Op when mask _is_ present.
228-
template <class SourceOp>
229-
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
230-
using OpRewritePattern<SourceOp>::OpRewritePattern;
231-
232-
private:
233-
LogicalResult matchAndRewrite(SourceOp sourceOp,
234-
PatternRewriter &rewriter) const final {
235-
auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
236-
if (!maskableOp)
237-
return failure();
238-
239-
// Op to update
240-
Operation *rootOp = sourceOp;
241-
242-
// If this Op is masked:
243-
// * update the insertion point to avoid inserting into the vector.mask
244-
// Op region,
245-
// * update the Op to rewrite so that it's the parent vector.mask Op
246-
OpBuilder::InsertionGuard guard(rewriter);
247-
MaskingOpInterface maskOp;
248-
if (maskableOp.isMasked()) {
249-
maskOp = maskableOp.getMaskingOp();
250-
rewriter.setInsertionPoint(maskOp);
251-
rootOp = maskOp;
252-
}
253-
254-
FailureOr<Value> newOp =
255-
matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
256-
if (failed(newOp))
257-
return failure();
258-
259-
rewriter.replaceOp(rootOp, *newOp);
260-
return success();
261-
}
262-
263-
public:
264-
// Matches SourceOp that can potentially be masked with `maskingOp`. If the
265-
// latter is present, returns an updated masking op (with a replacement for
266-
// `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
267-
virtual FailureOr<Value>
268-
matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
269-
PatternRewriter &rewriter) const = 0;
270-
};
271-
272214
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
273215
/// semantics to:
274216
/// ```
@@ -283,7 +225,7 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
283225
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
284226
/// the vector.contract op is a row-major matrix multiply.
285227
class ContractionOpToMatmulOpLowering
286-
: public MaskableOpRewritePattern<vector::ContractionOp> {
228+
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
287229
public:
288230
using MaskableOpRewritePattern::MaskableOpRewritePattern;
289231

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,37 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
279279
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
280280
}
281281

282+
// template <class SourceOp>
283+
// LogicalResult vector::MaskableOpRewritePattern::matchAndRewrite(
284+
// SourceOp sourceOp, PatternRewriter &rewriter) const final {
285+
// auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
286+
// if (!maskableOp)
287+
// return failure();
288+
289+
// // Op to update
290+
// Operation *rootOp = sourceOp;
291+
292+
// // If this Op is masked:
293+
// // * update the insertion point to avoid inserting into the vector.mask
294+
// // Op region,
295+
// // * update the Op to rewrite so that it's the parent vector.mask Op
296+
// OpBuilder::InsertionGuard guard(rewriter);
297+
// MaskingOpInterface maskOp;
298+
// if (maskableOp.isMasked()) {
299+
// maskOp = maskableOp.getMaskingOp();
300+
// rewriter.setInsertionPoint(maskOp);
301+
// rootOp = maskOp;
302+
// }
303+
304+
// FailureOr<Value> newOp =
305+
// matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
306+
// if (failed(newOp))
307+
// return failure();
308+
309+
// rewriter.replaceOp(rootOp, *newOp);
310+
// return success();
311+
// }
312+
282313
std::optional<StaticTileOffsetRange>
283314
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
284315
if (vType.getRank() <= targetRank)

0 commit comments

Comments
 (0)