
Commit 2a8ce8a

fixup! [mlir][linalg] Add scalable vectorisation for depthwise convolutions
* Add missing dyn dimension in a test
* Make sure "flattening" + "masked vectorisation" are not allowed
1 parent: 2528f8e

5 files changed (+49, -24 lines)

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
@@ -460,7 +460,8 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
 LogicalResult vectorizeOpPrecondition(Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes = {},
                                       ArrayRef<bool> inputScalableVecDims = {},
-                                      bool vectorizeNDExtract = false);
+                                      bool vectorizeNDExtract = false,
+                                      bool flatten1DDepthwiseConv = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as functional-style API calls.
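Note: the signature change above is additive (the new parameter defaults to false), so existing callers keep compiling. Below is a minimal sketch of a call site that requests flattening; the helper name, the op pointer and the tile sizes are made-up illustrations rather than code from this patch, and `linalg::vectorize` is assumed to accept the matching trailing flag, as the Vectorization.cpp hunks further down suggest.

// Hypothetical helper, for illustration only.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

static LogicalResult tryVectorizeFlattened(RewriterBase &rewriter,
                                           Operation *candidateOp) {
  SmallVector<int64_t> vectorSizes = {1, 8, 4};           // n, w, c tile sizes
  SmallVector<bool> scalableDims = {false, false, false}; // all fixed-size
  // New in this commit: the flattening request participates in the
  // precondition check, so "flattening" + "masked vectorisation" of a
  // dynamically-shaped conv is rejected up front rather than mid-rewrite.
  if (failed(linalg::vectorizeOpPrecondition(candidateOp, vectorSizes,
                                             scalableDims,
                                             /*vectorizeNDExtract=*/false,
                                             /*flatten1DDepthwiseConv=*/true)))
    return failure();
  return linalg::vectorize(rewriter, candidateOp, vectorSizes, scalableDims,
                           /*vectorizeNDExtract=*/false,
                           /*flatten1DDepthwiseConv=*/true);
}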

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 39 additions & 15 deletions
@@ -1715,9 +1715,17 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
+static LogicalResult
+vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv,
+                                   bool flatten1DDepthwiseConv) {
+  if (flatten1DDepthwiseConv) {
+    LDBG("Vectorization of flattened convs with dynamic shapes is not "
+         "supported\n");
+    return failure();
+  }
+
   if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) {
-    LDBG("Not a depth-wise 1D conv, dynamic shapes are not supported\n");
+    LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n");
     return failure();
   }
 
@@ -1735,9 +1743,10 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv) {
   return success();
 }
 
-static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+static LogicalResult
+vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, bool flatten1DDepthwiseConv) {
   if (isa<ConvolutionOpInterface>(op.getOperation()))
-    return vectorizeDynamicConvOpPrecondition(op);
+    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);
 
   // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
   // linalg.copy ops and ops that implement ContractionOpInterface for now.
@@ -1807,7 +1816,8 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
 static LogicalResult
 vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                               ArrayRef<int64_t> inputVectorSizes,
-                              bool vectorizeNDExtract) {
+                              bool vectorizeNDExtract,
+                              bool flatten1DDepthwiseConv) {
   // tensor with dimension of 0 cannot be vectorized.
   if (llvm::is_contained(linalgOp.getStaticShape(), 0))
     return failure();
@@ -1817,8 +1827,8 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                            inputVectorSizes)))
     return failure();
 
-  if (linalgOp.hasDynamicShape() &&
-      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp))) {
+  if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition(
+                                        linalgOp, flatten1DDepthwiseConv))) {
     LDBG("Dynamically-shaped op failed vectorization pre-conditions\n");
     return failure();
   }
@@ -1946,15 +1956,17 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
 LogicalResult mlir::linalg::vectorizeOpPrecondition(
     Operation *op, ArrayRef<int64_t> inputVectorSizes,
-    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract) {
+    ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract,
+    bool flatten1DDepthwiseConv) {
   if (failed(vectorizeScalableVectorPrecondition(op, inputVectorSizes,
                                                  inputScalableVecDims)))
     return failure();
 
   return TypeSwitch<Operation *, LogicalResult>(op)
       .Case<linalg::LinalgOp>([&](auto linalgOp) {
         return vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
-                                             vectorizeNDExtract);
+                                             vectorizeNDExtract,
+                                             flatten1DDepthwiseConv);
       })
       .Case<tensor::PadOp>([&](auto padOp) {
         return vectorizePadOpPrecondition(padOp, inputVectorSizes);
@@ -2003,7 +2015,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
   LLVM_DEBUG(llvm::dbgs() << "\n");
 
   if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims,
-                                     vectorizeNDExtract))) {
+                                     vectorizeNDExtract, flatten1DDepthwiseConv))) {
     LDBG("Vectorization pre-conditions failed\n");
     return failure();
   }
@@ -3180,6 +3192,9 @@ struct Conv1DGenerator
       scalableChDim = channelDimScalableFlag;
       useMasking = true;
     }
+
+    assert(!(useMasking && flatten) && "Unsupported flattened conv with dynamic shapes");
+
     // out{n, w, c}
     bindShapeDims(resShapedType, nSize, wSize);
 
@@ -3282,10 +3297,15 @@ struct Conv1DGenerator
       return kw * (wSize / wSizeStep) + w;
     };
 
+    // Note - the scalable flags are ignored as flattening combined with
+    // scalable vectorization is not supported.
     auto inOutFlattenSliceSizes =
         SmallVector<int64_t>{nSize, wSizeStep * cSize};
-    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
+    auto lhsTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resTypeAfterFlattening =
+        VectorType::get(inOutFlattenSliceSizes, resEltType);
+
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
@@ -3295,9 +3315,9 @@ struct Conv1DGenerator
           // Flatten the input and output vectors (collapse the channel
           // dimension)
          lhsVal = rewriter.create<vector::ShapeCastOp>(
-              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
-          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
-                                                        resVals[w]);
+              loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(
+              loc, resTypeAfterFlattening, resVals[w]);
         }
         resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal,
                                                   rhsVals[kw], resVal, flatten);
@@ -3353,6 +3373,10 @@ struct Conv1DGenerator
     lhs = promote(rewriter, loc, lhs, resTy);
 
     if (flatten) {
+      // NOTE: This following logic won't work for scalable vectors. For this
+      // reason, "flattening" is not supported when shapes are dynamic (this
+      // should be captured by one of the pre-conditions).
+
       // There are two options for handling the filter:
       //   * shape_cast(broadcast(filter))
       //   * broadcast(shuffle(filter))
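For context on why the new assert and precondition are needed: flattening collapses the W and C dimensions of each slice into a single static dimension before the multiply-accumulate, and that collapsed size has no counterpart once the channel dimension is scalable (it is only known as a multiple of vscale). The sketch below mirrors the inOutFlattenSliceSizes computation from the hunk above with made-up sizes and element type; it is an illustration, not code from this patch.

// Illustration only: flattened slice type for the static case.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/SmallVector.h"

mlir::VectorType flattenedSliceType(mlir::MLIRContext *ctx) {
  int64_t nSize = 1, wSizeStep = 8, cSize = 4; // static shapes: fine
  llvm::SmallVector<int64_t> inOutFlattenSliceSizes = {nSize,
                                                       wSizeStep * cSize};
  // Produces vector<1x32xf32>. With a scalable channel dim the trailing size
  // would be 8 * (4 * vscale), which is not a fixed integer, so no shape_cast
  // target type exists -- hence the precondition and assert added here.
  return mlir::VectorType::get(inOutFlattenSliceSizes,
                               mlir::Float32Type::get(ctx));
}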

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 3 additions & 3 deletions
@@ -306,14 +306,14 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
                                                     RewriterBase &rewriter) {
   auto loc = xfer->getLoc();
 
-  Value blah = TypeSwitch<Operation *, Value>(xfer)
+  Value base = TypeSwitch<Operation *, Value>(xfer)
                    .Case<vector::TransferReadOp>(
                        [&](auto readOp) { return readOp.getSource(); })
                    .Case<vector::TransferWriteOp>(
                        [&](auto writeOp) { return writeOp.getOperand(1); });
 
   SmallVector<OpFoldResult> mixedSourceDims =
-      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, blah)
-                         : memref::getMixedSizes(rewriter, loc, blah);
+      hasTensorSemantics ? tensor::getMixedSizes(rewriter, loc, base)
+                         : memref::getMixedSizes(rewriter, loc, base);
   return mixedSourceDims;
 }

mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Lines changed: 5 additions & 3 deletions
@@ -19,12 +19,14 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @depthwise_conv1d_ncw_cw(%input: memref<3x5x4xf32>, %filter: memref<5x1xf32>, %output: memref<3x5x4xf32>) {
+// Masked vectorisation of 1D depthwise CW convs is not yet supported
+
+func.func @depthwise_conv1d_ncw_cw(%input: memref<3x?x4xf32>, %filter: memref<?x1xf32>, %output: memref<3x?x4xf32>) {
   // expected-error @+1 {{Attempted to vectorize, but failed}}
   linalg.depthwise_conv_1d_ncw_cw
     {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
-    ins(%input, %filter : memref<3x5x4xf32>, memref<5x1xf32>)
-    outs(%output : memref<3x5x4xf32>)
+    ins(%input, %filter : memref<3x?x4xf32>, memref<?x1xf32>)
+    outs(%output : memref<3x?x4xf32>)
   return
 }
 
mlir/test/Dialect/Linalg/vectorize-conv-masked-and-scalable.mlir

Lines changed: 0 additions & 2 deletions
@@ -120,8 +120,6 @@ module attributes {transform.with_named_sequence} {
 // CHECK:           %[[OUT:.*]] = vector.mask %[[MASK_OUT]] { vector.transfer_write %[[OUT_INS]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : vector<1x8x[4]xi8>, tensor<1x8x?xi8> } : vector<1x8x[4]xi1> -> tensor<1x8x?xi8>
 // CHECK:           return %[[OUT]] : tensor<1x8x?xi8>
 
-
-
 // -----
 
 func.func @depthwise_conv1d_nwc_wc_3x5x4xf32_memref_dilation_2(
