Skip to content

Commit 158e6dd

Browse files
committed
fixup! [mlir][vector] Adds pattern rewrite for maskable Ops
Move pattern
1 parent f534bdd commit 158e6dd

File tree

3 files changed

+90
-59
lines changed

3 files changed

+90
-59
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,64 @@ bool isContiguousSlice(MemRefType memrefType, VectorType vectorType);
9898
std::optional<StaticTileOffsetRange>
9999
createUnrollIterator(VectorType vType, int64_t targetRank = 1);
100100

101+
/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
102+
/// masked (i.e. inside `vector.mask` Op region). In particular:
103+
/// 1. It matches `SourceOp` operation, Op.
104+
/// 2. If Op is masked, retrieves the mask and updates the insertion point to
105+
/// avoid inserting new ops into `vector.mask` Op region (which only allows
106+
/// one Op). If the Op is not masked, this step is a nop.
107+
/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
108+
/// required) in the matched `vector.mask` operation from step 2.
109+
///
110+
/// It frees the patterns implementing this class from worrying about the
111+
/// logic to update the insertion point. However, those patterns are still
112+
/// responsible for providing an updated version of:
113+
/// * the source Op when mask _is not_ present,
114+
/// * the source Op *and* the mask Op when mask _is_ present.
115+
template <class SourceOp>
116+
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
117+
using OpRewritePattern<SourceOp>::OpRewritePattern;
118+
119+
private:
120+
LogicalResult matchAndRewrite(SourceOp sourceOp,
121+
PatternRewriter &rewriter) const final {
122+
auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
123+
if (!maskableOp)
124+
return failure();
125+
126+
// Op to update
127+
Operation *rootOp = sourceOp;
128+
129+
// If this Op is masked:
130+
// * update the insertion point to avoid inserting into the vector.mask
131+
// Op region,
132+
// * update the Op to rewrite so that it's the parent vector.mask Op
133+
OpBuilder::InsertionGuard guard(rewriter);
134+
MaskingOpInterface maskOp;
135+
if (maskableOp.isMasked()) {
136+
maskOp = maskableOp.getMaskingOp();
137+
rewriter.setInsertionPoint(maskOp);
138+
rootOp = maskOp;
139+
}
140+
141+
FailureOr<Value> newOp =
142+
matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
143+
if (failed(newOp))
144+
return failure();
145+
146+
rewriter.replaceOp(rootOp, *newOp);
147+
return success();
148+
}
149+
150+
public:
151+
// Matches SourceOp that can potentially be masked with `maskingOp`. If the
152+
// latter is present, returns an updated masking op (with a replacement for
153+
// `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
154+
virtual FailureOr<Value>
155+
matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
156+
PatternRewriter &rewriter) const = 0;
157+
};
158+
101159
} // namespace vector
102160

103161
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 1 addition & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -211,64 +211,6 @@ static Value createMul(Location loc, Value x, Value y, bool isInt,
211211

212212
namespace {
213213

214-
/// A pattern for ops that implement `MaskableOpInterface` and that _might_ be
215-
/// masked (i.e. inside `vector.mask` Op region). In particular:
216-
/// 1. It matches `SourceOp` operation, Op.
217-
/// 2. If Op is masked, retrieves the mask and updates the insertion point to
218-
/// avoid inserting new ops into `vector.mask` Op region (which only allows
219-
/// one Op). If the Op is not masked, this step is a nop.
220-
/// 3. Invokes `matchAndRewriteMaskableOp` on Op that might be nested (not
221-
/// required) in the matched `vector.mask` operation from step 2.
222-
///
223-
/// It frees the patterns implementing this class from worrying about the
224-
/// logic to update the insertion point. However, those patterns are still
225-
/// responsible for providing an updated version of:
226-
/// * the source Op when mask _is not_ present,
227-
/// * the source Op *and* the mask Op when mask _is_ present.
228-
template <class SourceOp>
229-
struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
230-
using OpRewritePattern<SourceOp>::OpRewritePattern;
231-
232-
private:
233-
LogicalResult matchAndRewrite(SourceOp sourceOp,
234-
PatternRewriter &rewriter) const final {
235-
auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
236-
if (!maskableOp)
237-
return failure();
238-
239-
// Op to update
240-
Operation *rootOp = sourceOp;
241-
242-
// If this Op is masked:
243-
// * update the insertion point to avoid inserting into the vector.mask
244-
// Op region,
245-
// * update the Op to rewrite so that it's the parent vector.mask Op
246-
OpBuilder::InsertionGuard guard(rewriter);
247-
MaskingOpInterface maskOp;
248-
if (maskableOp.isMasked()) {
249-
maskOp = maskableOp.getMaskingOp();
250-
rewriter.setInsertionPoint(maskOp);
251-
rootOp = maskOp;
252-
}
253-
254-
FailureOr<Value> newOp =
255-
matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
256-
if (failed(newOp))
257-
return failure();
258-
259-
rewriter.replaceOp(rootOp, *newOp);
260-
return success();
261-
}
262-
263-
public:
264-
// Matches SourceOp that can potentially be masked with `maskingOp`. If the
265-
// latter is present, returns an updated masking op (with a replacement for
266-
// `sourceOp` nested inside). Otherwise, returns an updated `sourceOp`.
267-
virtual FailureOr<Value>
268-
matchAndRewriteMaskableOp(SourceOp sourceOp, MaskingOpInterface maskingOp,
269-
PatternRewriter &rewriter) const = 0;
270-
};
271-
272214
/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
273215
/// semantics to:
274216
/// ```
@@ -283,7 +225,7 @@ struct MaskableOpRewritePattern : OpRewritePattern<SourceOp> {
283225
/// This only kicks in when VectorTransformsOptions is set to OuterProduct and
284226
/// the vector.contract op is a row-major matrix multiply.
285227
class ContractionOpToMatmulOpLowering
286-
: public MaskableOpRewritePattern<vector::ContractionOp> {
228+
: public vector::MaskableOpRewritePattern<vector::ContractionOp> {
287229
public:
288230
using MaskableOpRewritePattern::MaskableOpRewritePattern;
289231

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,37 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
279279
return llvm::all_of(leadingDims, [](auto x) { return x == 1; });
280280
}
281281

282+
// template <class SourceOp>
283+
// LogicalResult vector::MaskableOpRewritePattern::matchAndRewrite(
284+
// SourceOp sourceOp, PatternRewriter &rewriter) const final {
285+
// auto maskableOp = dyn_cast<MaskableOpInterface>(sourceOp.getOperation());
286+
// if (!maskableOp)
287+
// return failure();
288+
289+
// // Op to update
290+
// Operation *rootOp = sourceOp;
291+
292+
// // If this Op is masked:
293+
// // * update the insertion point to avoid inserting into the vector.mask
294+
// // Op region,
295+
// // * update the Op to rewrite so that it's the parent vector.mask Op
296+
// OpBuilder::InsertionGuard guard(rewriter);
297+
// MaskingOpInterface maskOp;
298+
// if (maskableOp.isMasked()) {
299+
// maskOp = maskableOp.getMaskingOp();
300+
// rewriter.setInsertionPoint(maskOp);
301+
// rootOp = maskOp;
302+
// }
303+
304+
// FailureOr<Value> newOp =
305+
// matchAndRewriteMaskableOp(sourceOp, maskOp, rewriter);
306+
// if (failed(newOp))
307+
// return failure();
308+
309+
// rewriter.replaceOp(rootOp, *newOp);
310+
// return success();
311+
// }
312+
282313
std::optional<StaticTileOffsetRange>
283314
vector::createUnrollIterator(VectorType vType, int64_t targetRank) {
284315
if (vType.getRank() <= targetRank)

0 commit comments

Comments
 (0)