
[mlir][linalg][conv] Flatten the channel dimension when vectorizing #71918


Merged: 4 commits into llvm:main from andrzej/flatten_conv_v2, Dec 6, 2023

Conversation

@banach-space (Contributor) commented Nov 10, 2023

The current vectorization of 1D depthwise convolutions in Linalg is
sub-optimal for tensors with a low number of channels, e.g.:

linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>

That's due to the fact that ultimately (i.e. at LLVM level),
vectorization happens along the trailing dimension (i.e. the channel
dimension). In this case it leads to vectors with 3 elements (or worse,
if there's e.g. only 1 channel). For comparison, a 128-bit wide
vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/filter/output using the
vector.shape_cast operation:

    %sc_input = vector.shape_cast %input : vector<1x8x3xi8> to vector<1x24xi8>
    %sc_output = vector.shape_cast %output : vector<1x8x3xi8> to vector<1x24xi8>
    %b_filter = vector.broadcast %filter : vector<3xi8> to vector<1x8x3xi8>
    %sc_filter = vector.shape_cast %b_filter : vector<1x8x3xi8> to vector<1x24xi8>

This new vectorization mode is implemented in depthwiseConv by
inserting vector.shape_cast Ops before and after
depthwiseConv1dSliceAsMulAcc is invoked. It can be selected through
e.g. a transform dialect attribute:

  transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
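For reference, below is a complete transform sequence using this attribute, mirroring the test added in this patch (the match/get_parent_op steps are just one way of obtaining a suitable handle):

```mlir
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    // Find the depthwise conv, grab its enclosing isolated-from-above op, and
    // vectorize its children with channel flattening enabled.
    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
    transform.yield
  }
}
```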

A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.

Co-authored by: Bradley Smith [email protected]

@llvmbot (Member) commented Nov 10, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

The current vectorization of 1D depthwise convolutions in Linalg is
sub-optimal for tensors with a low number of channels, e.g.:

linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>

That's due to the fact that ultimately (i.e. at LLVM level),
vectorization happens along the trailing dimension (i.e. the channel
dimension). In this case it leads to vectors with 3 elements (or worse,
if there's e.g. only 1 channel). For comparison, a 128-bit wide
vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/output:

    %collapsed = tensor.collapse_shape %input [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
    %collapsed_0 = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>

(Note that for this to work, the filter is broadcast rather than
re-shaped. Please see the test cases for details).
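A rough sketch of what that means for the example above, based on the patch's depthwiseConv1dFlatSliceAsMulAcc (value names are illustrative; %filter_kw is the vector<3xi8> filter slice for one kw position):

```mlir
// Repeat the c = 3 filter elements once per output position (w = 8) ...
%f_rep = vector.shuffle %filter_kw, %filter_kw
    [0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]
    : vector<3xi8>, vector<3xi8>
// ... then broadcast to the flattened {n, w * c} = 1x24 result shape.
%f_flat = vector.broadcast %f_rep : vector<24xi8> to vector<1x24xi8>
```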

The new vectorization strategy is implemented in depthwiseConvFlatten,
which is based on depthwiseConvGeneric (i.e. the original
vectorization hook). The current vectorization is preserved as
the default option. The new vectorization can be selected through e.g. a
transform dialect attribute:

  transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}

A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.

Co-authored by: Bradley Smith <[email protected]>


Patch is 38.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71918.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+3-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+2-1)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+16-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+217-11)
  • (added) mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir (+222)
  • (modified) mlir/test/Dialect/Linalg/vectorize-convolution.mlir (+6-6)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..310efe164f93950 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2034,6 +2034,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   let arguments = (ins TransformHandleTypeInterface:$target,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
+                   UnitAttr:$flatten_1d_depthwise_conv,
                    UnitAttr:$disable_multi_reduction_to_contract_patterns,
                    UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs TransformHandleTypeInterface:$transformed);
@@ -2045,7 +2046,8 @@ def VectorizeChildrenAndApplyPatternsOp :
   let builders = [
     OpBuilder<(ins "Value":$target,
                CArg<"bool", "false">:$vectorizePadding,
-               CArg<"bool", "false">:$vectorizeNDExtract)>,
+               CArg<"bool", "false">:$vectorizeNDExtract,
+               CArg<"bool", "false">:$flatten1DDepthwise)>
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6547648f7495c31..a4aee1f45249c2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         ArrayRef<bool> inputScalableVecDims = {},
-                        bool vectorizeNDExtract = false);
+                        bool vectorizeNDExtract = false,
+                        bool flatten1DDepthwiseConv = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..35e8be7806928e1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2937,7 +2937,7 @@ LogicalResult TileUsingForallOp::verify() {
 
 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract) {
+    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
   result.addOperands(target);
   if (vectorizePadding) {
     result.addAttribute(
@@ -2951,6 +2951,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build(
             result.name),
         builder.getUnitAttr());
   }
+  if (flatten1DDepthwiseConv) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   result.addTypes(transform::AnyOpType::get(builder.getContext()));
 }
 
@@ -2959,22 +2965,26 @@ namespace {
 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
 struct VectorizationPattern : public RewritePattern {
   explicit VectorizationPattern(MLIRContext *context,
-                                bool vectorizeExtract = false)
+                                bool vectorizeExtract = false,
+                                bool flattenConv = false)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
-        vectorizeNDExtract(vectorizeExtract) {}
+        vectorizeNDExtract(vectorizeExtract),
+        flatten1DDepthwiseConv(flattenConv) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
     return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
-                     /*scalableVecDims=*/{}, vectorizeNDExtract);
+                     /*scalableVecDims=*/{}, vectorizeNDExtract,
+                     flatten1DDepthwiseConv);
   }
 
 private:
   /// Controls whether to vectorize `tensor.extract` when the input tensor is
   /// rank >= 2.
   bool vectorizeNDExtract = false;
+  bool flatten1DDepthwiseConv = false;
 };
 } // namespace
 
@@ -2991,7 +3001,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
 
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
+                                     getFlatten_1dDepthwiseConv());
 
   if (!getDisableTransferPermutationMapLoweringPatterns())
     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b8d82159856825f..f6f74b448edf9a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,8 +44,9 @@ using namespace mlir::linalg;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract) {
+                                      bool vectorizeNDExtract,
+                                      bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             // TODO: isaConvolutionOpInterface that can also infer from generic
             // features. Will require stride/dilation attributes inference.
             if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-              FailureOr<Operation *> convOr =
-                  vectorizeConvolution(rewriter, linalgOp);
+              FailureOr<Operation *> convOr = vectorizeConvolution(
+                  rewriter, linalgOp, flatten1DDepthwiseConv);
               if (succeeded(convOr)) {
                 llvm::append_range(results, (*convOr)->getResults());
                 return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv() {
+  FailureOr<Operation *> depthwiseConvGeneric() {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
         .getOperation();
   }
 
+  /// Generate a vector implementation for ("flatten channel dim"):
+  /// ```
+  ///   Op def: (     n,     w,     c,    kw)
+  ///    Iters: ({Par(), Par(), Par(), Red()})
+  ///   Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+  /// ```
+  /// c of the input/output is collapsed with w. kw is always unrolled and
+  /// broadcast to match w.
+  ///
+  /// TODO: Add support for non-unit stride/dilation
+  /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+  /// > 1.
+  FailureOr<Operation *> depthwiseConvFlatten() {
+    if (!valid)
+      return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+    int64_t nSize, iSize, wSize, cSize, kwSize;
+    // kernel{kw, c}
+    bindShapeDims(rhsShapedType, kwSize, cSize);
+    // out{n, w, c}
+    bindShapeDims(resShapedType, nSize, wSize);
+    // in{n, w, c}
+    bindShapeDims(lhsShapedType, nSize, iSize);
+
+    vector::TransferWriteOp write;
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+    if (strideW == 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit strides are not supported yet");
+    if (dilationW == 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit dilations are not supported yet");
+
+    Type lhsEltType = lhsShapedType.getElementType();
+    Type rhsEltType = rhsShapedType.getElementType();
+    Type resEltType = resShapedType.getElementType();
+    VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+    VectorType lhsType = VectorType::get(
+        {nSize,
+         // iw = (ow * sw + kw *  dw - 1) * c
+         //   (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+         (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+             cSize},
+        lhsEltType);
+
+    VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+
+    Value res, lhs, lhsFlat, resFlat;
+    // Read rhs slice of size {kw, c} @ [0, 0].
+    Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+                                                        ValueRange{zero, zero});
+
+    SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
+
+    // Flatten w and c dimensions
+    auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    lhsFlat =
+        linalgOp.hasTensorSemantics()
+            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+                  loc,
+                  RankedTensorType::get(lhsTypeCollapsed.getShape(),
+                                        lhsEltType),
+                  lhsShaped, reassociation)
+            : (Value)rewriter.create<memref::CollapseShapeOp>(
+                  loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
+                  lhsShaped, reassociation);
+    resFlat =
+        linalgOp.hasTensorSemantics()
+            ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+                  loc, RankedTensorType::get(resType.getShape(), resEltType),
+                  resShaped, reassociation)
+            : (Value)rewriter.create<memref::CollapseShapeOp>(
+                  loc, MemRefType::get(resType.getShape(), resEltType),
+                  resShaped, reassociation);
+
+    // Read lhs slice of size {n, (w * wSize + kw * dilationW) * c} @ [0,
+    // 0].
+    lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+                                                  ValueRange{zero, zero});
+    // Read res slice of size {n, w * c} @ [0, 0].
+    res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+                                                  ValueRange{zero, zero});
+
+    //===------------------------------------------------------------------===//
+    // Begin vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Unroll along kw and read slices of lhs and rhs.
+    SmallVector<Value> lhsVals, rhsVals, resVals;
+    // Extract lhs slice of size {n, wSizeStep * c}
+    //   @ [0, (sw * w + dw * kw) * cSize].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSize) {
+        lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+            loc, lhs,
+            /*offsets=*/
+            ArrayRef<int64_t>{0, (w * wSize + kw * dilationW) * cSize},
+            /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+            /*strides=*/ArrayRef<int64_t>{1, 1}));
+      }
+    }
+    // Extract rhs slice of size {c} @ [kw].
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+          loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+    }
+
+    // Extract res slice
+    // Flattened case:  {n, wSizeStep * c} @ [0, w].
+    for (int64_t w = 0; w < wSize; w += wSize) {
+      resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, res,
+          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+          /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1}));
+    }
+
+    auto linearIndex = [&](int64_t kw, int64_t w) {
+      return kw * (wSize / wSize) + w;
+    };
+
+    // Compute contraction:
+    //    O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
+    for (int64_t kw = 0; kw < kwSize; ++kw) {
+      for (int64_t w = 0; w < wSize; w += wSize) {
+        resVals[w] = depthwiseConv1dFlatSliceAsMulAcc(
+            rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+            resVals[w]);
+      }
+    }
+
+    // Its possible we failed to create the Fma.
+    if (!llvm::all_of(resVals, [](Value v) { return v; })) {
+      // Manually revert (in reverse order) to avoid leaving a bad IR state.
+      for (auto &collection :
+           {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
+        for (Value v : collection)
+          rewriter.eraseOp(v.getDefiningOp());
+      return rewriter.notifyMatchFailure(op, "failed to create FMA");
+    }
+
+    // Write back res slice. This does not depend on kw.
+    // Flattened case: {n, wSizeStep * c} @ [0, w].
+    for (int64_t w = 0; w < wSize; w += wSize) {
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, resVals[w], res,
+          /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+          /*strides=*/ArrayRef<int64_t>{1, 1});
+    }
+    //===------------------------------------------------------------------===//
+    // End vector-only rewrite part
+    //===------------------------------------------------------------------===//
+    // Write back res slice of size {n, w * c} @ [0, 0].
+    mlir::vector::TransferWriteOp tWrite =
+        rewriter.create<vector::TransferWriteOp>(loc, res, resFlat,
+                                                 ValueRange{zero, zero});
+
+    // A tensor has to be re-shaped back to it's original shape ...
+    if (linalgOp.hasTensorSemantics())
+      // Re-expand shape
+      return rewriter
+          .create<tensor::ExpandShapeOp>(loc, resShapedType, tWrite.getResult(),
+                                         reassociation)
+          .getOperation();
+    /// ... memrefs don't requie reshaping (re-shape is just a different view
+    /// into the same memref)
+    return tWrite.getOperation();
+  }
+
   /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
                                      Value lhs, Value rhs, Value res) {
@@ -2959,6 +3131,39 @@ struct Conv1DGenerator
     return rewriter.create<arith::AddIOp>(loc, mul, res);
   }
 
+  /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc
+  Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
+                                         Value lhs, Value rhs, Value res) {
+    auto rhsTy = rhs.getType().cast<ShapedType>();
+    auto resTy = res.getType().cast<ShapedType>();
+
+    lhs = promote(rewriter, loc, lhs, resTy);
+
+    auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
+    auto resSize = res.getType().cast<VectorType>().getShape()[1];
+
+    SmallVector<int64_t, 16> indicies;
+    for (int i = 0; i < resSize / rhsSize; ++i) {
+      for (int j = 0; j < rhsSize; ++j)
+        indicies.push_back(j);
+    }
+
+    rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indicies);
+
+    rhs = rewriter.create<vector::BroadcastOp>(
+        loc, resTy.clone(rhsTy.getElementType()), rhs);
+    rhs = promote(rewriter, loc, rhs, resTy);
+
+    if (!lhs || !rhs)
+      return nullptr;
+
+    if (resTy.getElementType().isa<FloatType>())
+      return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+    auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
+    return rewriter.create<arith::AddIOp>(loc, mul, res);
+  }
+
   /// Entry point for non-channeled convolution:
   ///   {{w + kw}, {kw}, {w}}
   FailureOr<Operation *> generateNonChanneledConv() {
@@ -3049,7 +3254,7 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv() {
+  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3265,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv();
+      return flatten ? depthwiseConvFlatten() : depthwiseConvGeneric();
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3125,8 +3330,9 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
   // use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3357,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv();
+  return e.generateDilatedConv(flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
new file mode 100644
index 000000000000000..6b0f920bfa42e7f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @flatten_tensor(%input: tensor<1x8x3xi8>, %filter: tensor<1x3xi8>, %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+  %res = linalg.depthwise_conv_1d_nwc_wc
+    {dilations = dense<1> : vector<1xi64>,
+    strides = dense<1> : vector<1xi64>}
+    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
+  return %res : tensor<1x8x3xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @flatten_tensor(
+// CHECK-SAME:                              %[[VAL_0:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME:                              %[[VAL_1:.*]]: tensor<1x3xi8>,
+// CHECK-SAME:                              %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> ...
[truncated]

@banach-space (Contributor, Author)

For context, here's the previous conversation on this specialisation:

We have tried to implement this as a pre-processing phase (as discussed in D149155), but that leads to either:

  • linalg.generics with indexing maps that are not permutations (these do not vectorize; see the sketch after this list), or
  • strided loads (which are bound to be slow).
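To illustrate the first point, here is a hypothetical sketch (not taken from any patch) of the indexing maps a pre-flattened depthwise conv generic would need, assuming 3 channels and unit stride/dilation; the input/output maps are affine but not projected permutations of the iteration dims, which is why the generic vectoriser rejects them:

```mlir
// Iteration dims: (n, w, c, kw); input/output collapsed to tensor<1x24xi8>.
#map_in  = affine_map<(n, w, c, kw) -> (n, (w + kw) * 3 + c)>  // input
#map_ker = affine_map<(n, w, c, kw) -> (kw, c)>                // filter
#map_out = affine_map<(n, w, c, kw) -> (n, w * 3 + c)>         // output
```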

So, we rejected the idea of any pre-processing and have considered two alternatives:

  1. add a special case to the vectoriser (implemented here), or
  2. implement a post-processing phase that would flatten the channel dimension (potential alternative).

The benefits of Option 1. are that:

  • it's available today (and this is now quite critical for us), and
  • it only requires changes in one place (i.e. the vectorizer).

The benefits of Option 2. would be that:

  • we wouldn't be adding any new cases to the vectorizer (is that the design goal?),

However, I expect that we might see some missing canonicalisations or other issues (TBH, I've only looked at it briefly).

I am happy to investigate Option 2. if that's preferable. However, I would prefer to do it in follow-up patches and to have this in-tree in the meantime.

WDYT? And thank you for all the feedback so far :) 🙏🏻

@nicolasvasilache (Contributor) left a comment


I made a first pass with some minor restructuring comments.

However, as I dug deeper, I have the impression that you are doing a one-off manual implementation of a new swap(vector.insert/extractStridedSlice, vector.shape_cast) pattern.

It seems to me you could achieve the same benefits by rewriting:

        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
                                                  lhsVals[linearIndex(kw, w)],
                                                  rhsVals[kw], resVals[w]);

as shape_cast + depthwiseConv1dSliceAsMulAcc + shape_cast and applying such new swap/propagation patterns.
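Concretely, for a single kw slice that wrapping would look roughly like this (a sketch, not the exact generated IR; shapes are borrowed from the f32 example later in this thread and value names are illustrative):

```mlir
// Flatten the lhs/accumulator slices, broadcast + flatten the filter slice ...
%lhs_f = vector.shape_cast %lhs_slice : vector<3x2x4xf32> to vector<3x8xf32>
%acc_f = vector.shape_cast %acc : vector<3x2x4xf32> to vector<3x8xf32>
%rhs_b = vector.broadcast %rhs_slice : vector<4xf32> to vector<3x2x4xf32>
%rhs_f = vector.shape_cast %rhs_b : vector<3x2x4xf32> to vector<3x8xf32>
// ... do the multiply-accumulate on the flattened 2-D vectors ...
%fma = vector.fma %lhs_f, %rhs_f, %acc_f : vector<3x8xf32>
// ... and restore the original shape for the surrounding insert/extract logic.
%acc_new = vector.shape_cast %fma : vector<3x8xf32> to vector<3x2x4xf32>
```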

I think this would fit within what you refer to as option 2.

In the current form, I am afraid the manual one-off application of that pattern to your current use case is unappealing to me.

@dcaballe (Contributor)

Should we schedule a meeting to discuss this? I agree with Nicolas that this might be a one-off case, but I also acknowledge that we have been looking at multiple options with significant drawbacks and we have to find a way to move this forward.

A few replies:

but that leads to either:
linalg.generics with indexing maps which are not permutations (does not vectorize), or
strided loads (which are bound to be slow).

Ultimately, this is something we may want to support. Semi-affine maps should be somewhat challenging but not extremely difficult to support (I already played a bit with this) and I think it has been a long-lasting TODO. Strided loads are another TODO, in general, as we are currently generating gather ops for loads with a short static stride.

we wouldn't be adding any new cases to the vectorizer (is that the design goal?),

I think the ultimate goal is to move the whole convolution decomposition step before the vectorizer but nobody has been brave enough to do it :). I think Nicolas' rewrite suggestion goes along these lines.

I haven't looked much into this mechanism but there is also the possibility of adding this specialization as a hook vectorization pattern. Not sure if we still use that mechanism but it's still there in the vectorizer so it should work. Perhaps we could extend the API so users can provide external ad-hoc vectorization patterns? (Just a random thought, perhaps this is not even feasible.)

Hopefully that helps!

@nicolasvasilache (Contributor)

@banach-space I really like how your recent change significantly reduced the complexity compared to the previous approach. Hopefully the shape_cast swaps nicely with the vector transfers next :)

github-actions bot commented Nov 16, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space (Contributor, Author)

@banach-space I really like how your recent change significantly reduced the complexity compared to the previous approach.

@nicolasvasilache Thank you - I was about to write a long reply in which I say "thank you" for your observation (which, indeed, makes things incredibly simpler). But you have already noticed that :)

Hopefully the shape_cast swaps nicely with the vector transfers next :)

That would be happening as a post vectorisation canonicalisation, right?

Should we schedule a meeting to discuss this?

Yes - let me ping you offline :)

Regarding linalg.generics with maps which are not permutations - we need to take a closer look and make sure that the resulting code for convolutions would be equally good. I can take another look at my examples, but I need to prioritise landing this first :) But yes, an important TODO.

As for "strided" loads/stores - super important TODO as well :)

In any case, I have refactored this patch following Nicolas' observation (apologies for over-complicating it so much before). Thank you, that was incredibly helpful 🙏🏻 .

Now, this change is still a "one off" pattern application within the vectoriser. Should this be done elsewhere? We'd need to match whatever depthwiseConv1dSliceAsMulAcc is generating post vectorisation, which would turn this into something quite complex.

Btw, I need to refine the tests - I am aware of that. And to benchmark.

Finally, to better visualise the current changes:

Input

func.func @conv_dill_2(%input: memref<3x5x4xf32>,
%filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
    outs(%output : memref<3x2x4xf32>)
  return
}

Vectorisation without shape casting

  func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
    %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
    %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
    %7 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
    %8 = vector.fma %3, %7, %2 : vector<3x2x4xf32>
    %9 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
    %10 = vector.fma %4, %9, %8 : vector<3x2x4xf32>
    vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
    return
  }

Vectorisation with shape casting

  func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
    %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
    %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
    %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
    %7 = vector.shape_cast %3 : vector<3x2x4xf32> to vector<3x8xf32>
    %8 = vector.shape_cast %2 : vector<3x2x4xf32> to vector<3x8xf32>
    %9 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
    %10 = vector.shape_cast %9 : vector<3x2x4xf32> to vector<3x8xf32>
    %11 = vector.fma %7, %10, %8 : vector<3x8xf32>
    %12 = vector.shape_cast %4 : vector<3x2x4xf32> to vector<3x8xf32>
    %13 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
    %14 = vector.shape_cast %13 : vector<3x2x4xf32> to vector<3x8xf32>
    %15 = vector.fma %12, %14, %11 : vector<3x8xf32>
    %16 = vector.shape_cast %15 : vector<3x8xf32> to vector<3x2x4xf32>
    vector.transfer_write %16, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
    return
  }

@dcaballe (Contributor) left a comment


Cool!

@banach-space (Contributor, Author)

Thank you for all the suggestions!

@nicolasvasilache I believe that the current implementation addresses all of your comments.

Hopefully the shape_cast swaps nicely with the vector transfers next :)

There's hope 🤞🏻 , but it's going to be a journey 😅 (I have already landed #73522). If there's no new traffic then I will merge this soon.

-Andrzej

@banach-space force-pushed the andrzej/flatten_conv_v2 branch from 75cc4a6 to 8c877d6 on December 5, 2023 18:56
@dcaballe (Contributor) commented Dec 6, 2023

Nicolas is OOO, so I would wait a bit, then land, and address any further comments in a follow-up commit. That should be OK as we discussed this approach in a meeting.

The current vectorization of 1D depthwise convolutions in Linalg is
_sub-optimal_ for tensors with a low number of channels, e.g.:

```mlir
linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<1> : vector<1xi64>,
    strides = dense<1> : vector<1xi64>}
    ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
    outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```

That's due to the fact that ultimately (i.e. at LLVM level),
vectorization happens along the trailing dimension (i.e. the channel
dimension). In this case it leads to vectors with 3 elements (or worse,
if there's e.g. only 1 channel). For comparison, a 128-bit wide
vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel
dimension into the width dimension of the input/output:

```mlir
    %collapsed = tensor.collapse_shape %input [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
    %collapsed_0 = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
```

(Note that for this to work, the filter is broadcast rather than
re-shaped. Please see the test cases for details).

The new vectorization strategy is implemented in `depthwiseConvFlatten`,
which is based on `depthwiseConvGeneric` (i.e. the original
vectorization hook). The current vectorization is preserved as
the default option. The new vectorization can be selected through e.g. a
transform dialect attribute:

```mlir
  transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```

A forthcoming patch will implement a strategy to automatically switch
between the two implementations, depending on the shape of the input
tensors.

Co-authored by: Bradley Smith <[email protected]>
…izing

Following on from Nicolas' observation, this commit refactors the
implementation to simply replace:
```
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
                                          lhsVals[linearIndex(kw, w)],
                                          rhsVals[kw], resVals[w]);
```

with shape_cast + depthwiseConv1dSliceAsMulAcc + shape_cast.
…izing

Final tweaks (more comments, revert unrelated change in a test file)
@banach-space force-pushed the andrzej/flatten_conv_v2 branch from 8c877d6 to 8e8f56d on December 6, 2023 18:41
@banach-space merged commit 03c2f5d into llvm:main on Dec 6, 2023
banach-space added a commit to banach-space/llvm-project that referenced this pull request Dec 11, 2023
Updates the vectorisation of 1D depthwise convolution when flattening
the channel dimension (introduced in llvm#71918). In particular - how the
convolution filter is "flattened". ATM, the vectoriser will use
`vector.shape_cast`:

```mlir
  %b_filter = vector.broadcast %filter : vector<4xf32> to vector<3x2x4xf32>
  %sc_filter = vector.shape_cast %b_filter : vector<3x2x4xf32> to vector<3x8xf32>
```

This lowering is not ideal - `vector.shape_cast` can be convenient when
it's folded away, but that's not happening in this case. Instead, this
patch updates the vectoriser to use `vector.shuffle` (the overall result
is identical):

```mlir
  %sh_filter = vector.shuffle %filter, %filter
      [0, 1, 2, 3, 0, 1, 2, 3] : vector<4xf32>, vector<4xf32>
  %b_filter = vector.broadcast %sh_filter : vector<8xf32> to vector<3x8xf32>
```
banach-space added a commit that referenced this pull request Dec 15, 2023
@banach-space deleted the andrzej/flatten_conv_v2 branch on March 8, 2024 14:44