Commit 0a06be8

fixup! [mlir][linalg] Add scalable vectorisation for depthwise convolutions
Addressing PR comments:
  - add CSE in tests, update check-lines accordingly
  - add support for plain (non-scalable) masked vectorisation
  - moved pre-conditions for vectorisation to a dedicated hook
1 parent d5287b2 · commit 0a06be8
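
The values this fixup threads through (the channel-dim vector size plus a per-dim scalability flag) are the ones consumed by `linalg::vectorize`. As a rough illustration of the second and third bullets in the commit message, here is a minimal, hypothetical C++ call-site sketch that requests masked vectorisation of a `linalg.depthwise_conv_1d_nwc_wc` with a scalable channel dimension; the {n, w, c, kw} vector-size layout, the helper name and the `vectorizeNDExtract` parameter are assumptions for illustration, not part of this commit:

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper: vectorise a 1D depthwise conv whose channel dim is
// dynamic, using masking with a scalable channel dimension.
static LogicalResult vectorizeDepthwiseConvScalable(RewriterBase &rewriter,
                                                    Operation *convOp) {
  // Assumed vector sizes for the {n, w, c, kw} iteration space; the
  // depthwise-conv path only reads the channel entry (index 2).
  SmallVector<int64_t> vectorSizes = {1, 8, 4, 1};
  // Mark only the channel dim as scalable (vector<[4]x...>-style types).
  SmallVector<bool> scalableVecDims = {false, false, true, false};
  return linalg::vectorize(rewriter, convOp, vectorSizes, scalableVecDims,
                           /*vectorizeNDExtract=*/false,
                           /*flatten1DDepthwiseConv=*/false);
}

Passing all-false scalable flags with the same vector sizes would exercise the plain (non-scalable) masked path this fixup adds: `useMasking` is now set independently of `scalableChDim`.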

File tree

2 files changed: +153 -113 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 39 additions & 25 deletions
@@ -55,6 +55,7 @@ using namespace mlir::linalg;
 static FailureOr<Operation *>
 vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                      ArrayRef<int64_t> inputVecSizes = {},
+                     ArrayRef<bool> inputVecScalableFlags = {},
                      bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
@@ -1713,21 +1714,31 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  if (auto conv = dyn_cast<linalg::DepthwiseConv1DNwcWcOp>(op.getOperation())) {
-    // Support dynamic shapes in 1D depthwise convolution, but only in the
-    // _channel_ dimension. That's exclusively to support scalable
-    // vectorisation.
-    auto lhsShaped = op.getDpsInputOperand(0)->get();
-    ArrayRef<int64_t> lhsShape =
-        cast<ShapedType>(lhsShaped.getType()).getShape();
-    auto shapeWithoutCh = lhsShape.drop_back(1);
-    if (ShapedType::isDynamicShape(shapeWithoutCh))
-      return failure();
+static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
+  if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv.getOperation())) {
+    LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
+    return failure();
+  }
 
-    return success();
+  // Support dynamic shapes in 1D depthwise convolution, but only in the
+  // _channel_ dimension. That's exclusively to support scalable
+  // vectorisation.
+  auto lhs = conv.getDpsInputOperand(0)->get();
+  ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape();
+  auto shapeWithoutCh = lhsShape.drop_back(1);
+  if (ShapedType::isDynamicShape(shapeWithoutCh)) {
+    LDBG("Dynamically-shaped op vectorization precondition failed: only "
+         "channel dim can be dynamic\n");
+    return failure();
   }
 
+  return success();
+}
+
+static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  if (isa<ConvolutionOpInterface>(op.getOperation()))
+    return vectorizeDynamicConvOpPrecondition(op);
+
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
   if (!isElementwise(op) &&
@@ -2016,7 +2027,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
         // inference.
         if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
           FailureOr<Operation *> convOr = vectorizeConvolution(
-              rewriter, linalgOp, inputVectorSizes, flatten1DDepthwiseConv);
+              rewriter, linalgOp, inputVectorSizes, inputScalableVecDims,
+              flatten1DDepthwiseConv);
           if (succeeded(convOr)) {
             llvm::append_range(results, (*convOr)->getResults());
             return success();
@@ -3150,19 +3162,21 @@ struct Conv1DGenerator
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
   FailureOr<Operation *> depthwiseConv(uint64_t channelDimVecSize,
+                                       bool channelDimScalableFlag,
                                        bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
     bool scalableChDim = false;
+    bool useMasking = false;
     int64_t nSize, wSize, cSize, kwSize;
     // kernel{kw, c}
     bindShapeDims(rhsShapedType, kwSize, cSize);
-    // Dynamic channel size implies scalable vectorisation
     if (ShapedType::isDynamic(cSize)) {
       assert(channelDimVecSize != 0 && "Channel dim vec size must be > 0");
       cSize = channelDimVecSize;
-      scalableChDim = true;
+      scalableChDim = channelDimScalableFlag;
+      useMasking = true;
     }
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
@@ -3197,13 +3211,10 @@ struct Conv1DGenerator
     auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
                                ArrayRef<bool> scalableDims,
                                Operation *opToMask) {
-      bool scalableChDim = scalableDims.back();
-      if (!scalableChDim)
+      if (!useMasking)
        return opToMask;
-
      auto maskType =
          VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
-
      SmallVector<OpFoldResult> mixedSourceDims =
          hasTensorSemantics
              ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
@@ -3479,6 +3490,7 @@ struct Conv1DGenerator
   /// Entry point that transposes into the common form:
   /// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
   FailureOr<Operation *> generateDilatedConv(uint64_t vecChDimSize = 0,
+                                             bool vecChDimScalableFlag = true,
                                              bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
@@ -3490,7 +3502,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv(vecChDimSize, flatten);
+      return depthwiseConv(vecChDimSize, vecChDimScalableFlag, flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3556,10 +3568,9 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *>
-vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
-                     ArrayRef<int64_t> inputVecSizes,
-                     bool flatten1DDepthwiseConv) {
+static FailureOr<Operation *> vectorizeConvolution(
+    RewriterBase &rewriter, LinalgOp op, ArrayRef<int64_t> inputVecSizes,
+    ArrayRef<bool> inputScalableVecDims, bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can
   // simply use them if present, otherwise use the default and let the generic
@@ -3586,12 +3597,15 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
     return res;
 
   uint64_t vecChDimSize = ShapedType::kDynamic;
+  bool vecChDimScalableFlag = false;
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
     vecChDimSize = inputVecSizes[2];
+    vecChDimScalableFlag = inputScalableVecDims[2];
   }
-  return e.generateDilatedConv(vecChDimSize, flatten1DDepthwiseConv);
+  return e.generateDilatedConv(vecChDimSize, vecChDimScalableFlag,
+                               flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
