@@ -1715,9 +1715,17 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
+static LogicalResult
+vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
+                                   bool flatten1DDepthwiseConv) {
+  if (flatten1DDepthwiseConv) {
+    LDBG("Vectorization of flattened convs with dynamic shapes is not "
+         "supported\n");
+    return failure();
+  }
+
   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
-    LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
+    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
     return failure();
   }
 
@@ -1735,9 +1743,10 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult
+vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv) {
   if (isa<ConvolutionOpInterface>(op.getOperation()))
-    return vectorizeDynamicConvOpPrecondition(op);
+    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
@@ -1807,7 +1816,8 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
-                              bool vectorizeNDExtract) {
+                              bool vectorizeNDExtract,
+                              bool flatten1DDepthwiseConv) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
     return failure();
@@ -1817,8 +1827,8 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                       inputVectorSizes)))
     return failure();
 
-  if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
+                                        linalgOp, flatten1DDepthwiseConv))) {
     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
   }
@@ -1946,15 +1956,17 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
     Operation *op, ArrayRef<int64_t> inputVectorSizes,
-    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv) {
   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                  inputScalableVecDims)))
     return failure();
 
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract);
+                                             vectorizeNDExtract,
+                                             flatten1DDepthwiseConv);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
@@ -2003,7 +2015,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
-                                     vectorizeNDExtract))) {
+                                     vectorizeNDExtract, flatten1DDepthwiseConv))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
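
Taken together, the hunks above simply thread the new flatten1DDepthwiseConv flag from the
public vectorize()/vectorizeOpPrecondition() entry points down to
vectorizeDynamicConvOpPrecondition(), so a request to flatten a dynamically shaped
depthwise conv is rejected before any IR is rewritten. A rough sketch of a call site
exercising the flag (the include path, helper name and argument values are illustrative
assumptions, not part of this patch):

// Hypothetical call site, for illustration only. The parameter order follows
// the signatures visible in this patch; everything else is assumed context.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static LogicalResult tryVectorizeDepthwiseConv(RewriterBase &rewriter,
                                               linalg::LinalgOp convOp) {
  // With flatten1DDepthwiseConv = true, vectorizeOpPrecondition() (and hence
  // vectorizeDynamicConvOpPrecondition()) fails for dynamically shaped convs,
  // so vectorize() bails out before rewriting anything.
  return linalg::vectorize(rewriter, convOp.getOperation(),
                           /*inputVectorSizes=*/{},
                           /*inputScalableVecDims=*/{},
                           /*vectorizeNDExtract=*/false,
                           /*flatten1DDepthwiseConv=*/true);
}
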
@@ -3180,6 +3192,9 @@ struct Conv1DGenerator
       scalableChDim = channelDimScalableFlag;
       useMasking = true;
     }
+
+    assert(!(useMasking && flatten) && "Unsupported flattened conv with dynamic shapes");
+
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3282,10 +3297,15 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
+    // Note - the scalable flags are ignored as flattening combined with
+    // scalable vectorization is not supported.
     auto inOutFlattenSliceSizes =
         SmallVector<int64_t>{nSize, wSizeStep * cSize};
-    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
+    auto lhsTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, resEltType);
+
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -3295,9 +3315,9 @@ struct Conv1DGenerator
           // Flatten the input and output vectors (collapse the channel
           // dimension)
           lhsVal = rewriter.create<vector::ShapeCastOp>(
-              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
-          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
-                                                        resVals[w]);
+              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(
+              loc, resTypeAfterFlattening, resVals[w]);
         }
         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                   rhsVals[kw], resVal, flatten);
@@ -3353,6 +3373,10 @@ struct Conv1DGenerator
     lhs = promote(rewriter, loc, lhs, resTy);
 
     if (flatten) {
+      // NOTE: This following logic won't work for scalable vectors. For this
+      // reason, "flattening" is not supported when shapes are dynamic (this
+      // should be captured by one of the pre-conditions).
+
       // There are two options for handling the filter:
       // * shape_cast(broadcast(filter))
       // * broadcast(shuffle(filter))
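
For background on the assert and the NOTE above: flattening collapses the W and C
dimensions of each {n, w, c} slice into a single trailing dimension of size
wSizeStep * cSize before the multiply-accumulate, and that product only yields a static
vector type when the channel size is known at compile time. A minimal, self-contained
sketch of the shape arithmetic (plain C++ with made-up sizes, not MLIR code):

// Standalone illustration of the arithmetic behind inOutFlattenSliceSizes.
// The concrete sizes below are made up.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  int64_t nSize = 2, wSizeStep = 4, cSize = 8; // all must be static
  // Flattening collapses the W and C dims of an {n, w, c} slice into one:
  // {nSize, wSizeStep, cSize} -> {nSize, wSizeStep * cSize}.
  std::vector<int64_t> inOutFlattenSliceSizes = {nSize, wSizeStep * cSize};
  std::cout << "flattened slice: " << inOutFlattenSliceSizes[0] << " x "
            << inOutFlattenSliceSizes[1] << "\n";
  // With a scalable or dynamic channel size, cSize is unknown at compile time,
  // so this product cannot be folded into a static vector type - which is why
  // the patch rejects flattening for dynamically shaped convs.
  return 0;
}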