@@ -55,6 +55,7 @@ using namespace mlir::linalg;
55
55
/// Forward declaration of the convolution vectorizer; the definition appears
/// further down in this file. `inputVecSizes` and `inputVecScalableFlags`
/// default to empty and `flatten1DDepthwiseConv` defaults to false.
/// NOTE(review): the definition spells the scalable-flags parameter
/// `inputScalableVecDims` - consider unifying the two names.
static FailureOr<Operation *>
56
56
vectorizeConvolution (RewriterBase &rewriter, LinalgOp convOp,
57
57
ArrayRef<int64_t > inputVecSizes = {},
58
+ ArrayRef<bool > inputVecScalableFlags = {},
58
59
bool flatten1DDepthwiseConv = false );
59
60
60
61
/// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1713,21 +1714,31 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
1713
1714
return success ();
1714
1715
}
1715
1716
1716
- static LogicalResult vectorizeDynamicLinalgOpPrecondition (linalg::LinalgOp op) {
1717
- if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation ())) {
1718
- // Support dynamic shapes in 1D depthwise convolution, but only in the
1719
- // _channel_ dimension. That's exclusively to support scalable
1720
- // vectorisation.
1721
- auto lhsShaped = op.getDpsInputOperand (0 )->get ();
1722
- ArrayRef<int64_t > lhsShape =
1723
- cast<ShapedType>(lhsShaped.getType ()).getShape ();
1724
- auto shapeWithoutCh = lhsShape.drop_back (1 );
1725
- if (ShapedType::isDynamicShape (shapeWithoutCh))
1726
- return failure ();
1717
/// Precondition check for vectorizing a dynamically-shaped convolution op.
/// Succeeds only when `conv` is a linalg::DepthwiseConv1DNwcWcOp whose first
/// DPS input is static in every dimension except the last one; returns
/// failure (and emits a debug message) otherwise.
/// NOTE(review): this used to be described as existing "exclusively to
/// support scalable vectorisation"; depthwiseConv now enables masking for any
/// dynamic channel size and takes scalability from a separate flag, so that
/// rationale looks outdated - confirm.
+ static LogicalResult vectorizeDynamicConvOpPrecondition (linalg::LinalgOp conv) {
1718
+ if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv.getOperation ())) {
1719
+ LDBG (" Not a depth-wise 1D conv, dynamic shapes are not supported\n " );
1720
+ return failure ();
1721
+ }
1727
1722
1728
- return success ();
1723
+ // Dynamic shapes are accepted for 1D depthwise convolution, but only in the
1724
+ // _channel_ dimension, i.e. the last dimension of the first input; all
1725
+ // remaining dimensions must be static.
1726
+ auto lhs = conv.getDpsInputOperand (0 )->get ();
1727
+ ArrayRef<int64_t > lhsShape = cast<ShapedType>(lhs.getType ()).getShape ();
1728
+ auto shapeWithoutCh = lhsShape.drop_back (1 );
1729
+ if (ShapedType::isDynamicShape (shapeWithoutCh)) {
1730
+ LDBG (" Dynamically-shaped op vectorization precondition failed: only "
1731
+ " channel dim can be dynamic\n " );
1732
+ return failure ();
1729
1733
}
1730
1734
1735
+ return success ();
1736
+ }
1737
+
1738
+ static LogicalResult vectorizeDynamicLinalgOpPrecondition (linalg::LinalgOp op) {
1739
+ if (isa<ConvolutionOpInterface>(op.getOperation ()))
1740
+ return vectorizeDynamicConvOpPrecondition (op);
1741
+
1731
1742
// TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
1732
1743
// linalg.copy ops and ops that implement ContractionOpInterface for now.
1733
1744
if (!isElementwise (op) &&
@@ -2016,7 +2027,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2016
2027
// inference.
2017
2028
if (isa<ConvolutionOpInterface>(linalgOp.getOperation ())) {
2018
2029
FailureOr<Operation *> convOr = vectorizeConvolution (
2019
- rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
2030
+ rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
2031
+ flatten1DDepthwiseConv);
2020
2032
if (succeeded (convOr)) {
2021
2033
llvm::append_range (results, (*convOr)->getResults ());
2022
2034
return success ();
@@ -3150,19 +3162,21 @@ struct Conv1DGenerator
3150
3162
/// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
3151
3163
/// > 1.
3152
3164
FailureOr<Operation *> depthwiseConv (uint64_t channelDimVecSize,
3165
+ bool channelDimScalableFlag,
3153
3166
bool flatten) {
3154
3167
if (!valid)
3155
3168
return rewriter.notifyMatchFailure (op, " unvectorizable depthwise conv" );
3156
3169
3157
3170
bool scalableChDim = false ;
3171
+ bool useMasking = false ;
3158
3172
int64_t nSize, wSize, cSize, kwSize;
3159
3173
// kernel{kw, c}
3160
3174
bindShapeDims (rhsShapedType, kwSize, cSize);
3161
- // Dynamic channel size implies scalable vectorisation
3162
3175
if (ShapedType::isDynamic (cSize)) {
3163
3176
assert (channelDimVecSize != 0 && " Channel dim vec size must be > 0" );
3164
3177
cSize = channelDimVecSize;
3165
- scalableChDim = true ;
3178
+ scalableChDim = channelDimScalableFlag;
3179
+ useMasking = true ;
3166
3180
}
3167
3181
// out{n, w, c}
3168
3182
bindShapeDims (resShapedType, nSize, wSize);
@@ -3197,13 +3211,10 @@ struct Conv1DGenerator
3197
3211
auto maybeMaskXferOp = [&](ArrayRef<int64_t > maskShape,
3198
3212
ArrayRef<bool > scalableDims,
3199
3213
Operation *opToMask) {
3200
- bool scalableChDim = scalableDims.back ();
3201
- if (!scalableChDim)
3214
+ if (!useMasking)
3202
3215
return opToMask;
3203
-
3204
3216
auto maskType =
3205
3217
VectorType::get (maskShape, rewriter.getI1Type (), scalableDims);
3206
-
3207
3218
SmallVector<OpFoldResult> mixedSourceDims =
3208
3219
hasTensorSemantics
3209
3220
? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
@@ -3479,6 +3490,7 @@ struct Conv1DGenerator
3479
3490
/// Entry point that transposes into the common form:
3480
3491
///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
3481
3492
FailureOr<Operation *> generateDilatedConv (uint64_t vecChDimSize = 0 ,
3493
+ bool vecChDimScalableFlag = true ,
3482
3494
bool flatten = false ) {
3483
3495
AffineExpr n, w, c, kw;
3484
3496
bindDims (ctx, n, w, c, kw);
@@ -3490,7 +3502,7 @@ struct Conv1DGenerator
3490
3502
if (layout ({/* lhsIndex*/ {n, strideW * w + dilationW * kw, c},
3491
3503
/* rhsIndex*/ {kw, c},
3492
3504
/* resIndex*/ {n, w, c}}))
3493
- return depthwiseConv (vecChDimSize, flatten);
3505
+ return depthwiseConv (vecChDimSize, vecChDimScalableFlag, flatten);
3494
3506
3495
3507
return rewriter.notifyMatchFailure (op, " not a depthwise::Nwc layout" );
3496
3508
}
@@ -3556,10 +3568,9 @@ struct Conv1DGenerator
3556
3568
3557
3569
/// Helper function to vectorize a LinalgOp with convolution semantics.
3558
3570
// TODO: extend the generic vectorization to support windows and drop this.
3559
- static FailureOr<Operation *>
3560
- vectorizeConvolution (RewriterBase &rewriter, LinalgOp op,
3561
- ArrayRef<int64_t > inputVecSizes,
3562
- bool flatten1DDepthwiseConv) {
3571
+ static FailureOr<Operation *> vectorizeConvolution (
3572
+ RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t > inputVecSizes,
3573
+ ArrayRef<bool > inputScalableVecDims, bool flatten1DDepthwiseConv) {
3563
3574
// The ConvolutionOpInterface gives us guarantees of existence for
3564
3575
// strides/dilations. However, we do not need to rely on those, we can
3565
3576
// simply use them if present, otherwise use the default and let the generic
@@ -3586,12 +3597,15 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
3586
3597
return res;
3587
3598
3588
3599
uint64_t vecChDimSize = ShapedType::kDynamic ;
3600
+ bool vecChDimScalableFlag = false ;
3589
3601
if (!inputVecSizes.empty ()) {
3590
3602
// Only use the input vector size corresponding to the channel dim. Other
3591
3603
// vector dims will be inferred from the Ops.
3592
3604
vecChDimSize = inputVecSizes[2 ];
3605
+ vecChDimScalableFlag = inputScalableVecDims[2 ];
3593
3606
}
3594
- return e.generateDilatedConv (vecChDimSize, flatten1DDepthwiseConv);
3607
+ return e.generateDilatedConv (vecChDimSize, vecChDimScalableFlag,
3608
+ flatten1DDepthwiseConv);
3595
3609
}
3596
3610
3597
3611
struct VectorizeConvolution : public OpInterfaceRewritePattern <LinalgOp> {
0 commit comments