Skip to content

Commit 4bba066

Browse files
author
Jerry Wu
committed
Handle masked op in VectorDropLeadUnitDim patterns
1 parent a01b58a commit 4bba066

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,9 @@ struct CastAwayTransferReadLeadingOneDim
223223

224224
LogicalResult matchAndRewrite(vector::TransferReadOp read,
225225
PatternRewriter &rewriter) const override {
226+
// Not supported masked op yet.
227+
if (cast<MaskableOpInterface>(read.getOperation()).isMasked())
228+
return failure();
226229
// TODO: support 0-d corner case.
227230
if (read.getTransferRank() == 0)
228231
return failure();
@@ -274,6 +277,9 @@ struct CastAwayTransferWriteLeadingOneDim
274277

275278
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
276279
PatternRewriter &rewriter) const override {
280+
// Not supported masked op yet.
281+
if (cast<MaskableOpInterface>(write.getOperation()).isMasked())
282+
return failure();
277283
// TODO: support 0-d corner case.
278284
if (write.getTransferRank() == 0)
279285
return failure();
@@ -325,6 +331,9 @@ struct CastAwayTransferWriteLeadingOneDim
325331
LogicalResult
326332
mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
327333
RewriterBase &rewriter) {
334+
// Not supported masked op yet.
335+
if (cast<MaskableOpInterface>(contractOp.getOperation()).isMasked())
336+
return failure();
328337
VectorType oldAccType = dyn_cast<VectorType>(contractOp.getAccType());
329338
if (oldAccType == nullptr)
330339
return failure();

0 commit comments

Comments
 (0)