Skip to content

Commit bb95fed

Browse files
committed
[mlir][linalg] Add a new helper hook - isVectorizable
The newly added hook simply returns `false` for Ops for which there's no "vectorization logic" in the Linalg Vectorizer (i.e. the `vectorize` method). It's added so that the following two TD ops expose identical level of functionality (that's not the case ATM): * `transform.structured.vectorize_children_and_apply_patterns` * `transform.structured.vectorize`
1 parent eb3361d commit bb95fed

File tree

3 files changed

+18
-7
lines changed

3 files changed

+18
-7
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -762,6 +762,13 @@ LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst);
762762
/// memory is freed when going outside of the scope.
763763
LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
764764

765+
/// Check if this Op is vectorizable. All Linalg Op are vectorizable, as well
766+
/// as selected Tensor Ops. Note that this is merely a high level check and
767+
/// that the vectorizer also requires various additional pre-conditions to be
768+
/// met for it to work. These are only checked for Ops that are supported,
769+
/// other Ops should be rejected early.
770+
bool isVectorizable(Operation *);
771+
765772
/// Emit a suitable vector form for an operation. If provided,
766773
/// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes`
767774
/// must match the rank of the iteration space of the operation and the sizes

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3411,11 +3411,11 @@ struct VectorizationPattern : public RewritePattern {
34113411
flatten1DDepthwiseConv(flattenConv) {}
34123412
LogicalResult matchAndRewrite(Operation *op,
34133413
PatternRewriter &rewriter) const override {
3414-
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
3415-
if (!linalgOp)
3416-
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
3417-
return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
3418-
/*scalableVecDims=*/{}, vectorizeNDExtract,
3414+
if (!linalg::isVectorizable(op))
3415+
return rewriter.notifyMatchFailure(op,
3416+
"Unsupported Op, cannot vectorize");
3417+
return vectorize(rewriter, op, /*inputVectorSizes=*/{},
3418+
/*inputScalableVecDims=*/{}, vectorizeNDExtract,
34193419
flatten1DDepthwiseConv);
34203420
}
34213421

@@ -3496,8 +3496,7 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
34963496

34973497
// TODO: Check that the correct number of vectorSizes was provided.
34983498
for (Operation *target : targets) {
3499-
if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
3500-
target)) {
3499+
if (!linalg::isVectorizable(target)) {
35013500
return mlir::emitSilenceableFailure(target->getLoc())
35023501
<< "Unsupported Op, cannot vectorize";
35033502
}

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2129,6 +2129,11 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
21292129
}
21302130
}
21312131

2132+
bool mlir::linalg::isVectorizable(Operation *op) {
2133+
return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
2134+
op);
2135+
}
2136+
21322137
/// Emit a suitable vector form for an operation. If provided,
21332138
/// `inputVectorSizes` are used to vectorize this operation.
21342139
/// `inputVectorSizes` must match the rank of the iteration space of the

0 commit comments

Comments
 (0)