[mlir][linalg][conv] Flatten the channel dimension when vectorizing #71918
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

The current vectorization of 1D depthwise convolutions in Linalg is sub-optimal for tensors with a low number of channel dimensions, e.g.:

```mlir
linalg.depthwise_conv_1d_nwc_wc
  {dilations = dense<1> : vector<1xi64>,
   strides = dense<1> : vector<1xi64>}
  ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
  outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```

That's due to the fact that ultimately (i.e. at LLVM level), vectorization happens along the trailing dimension (i.e. the channel dimension). In this case it leads to vectors with 3 elements (or worse, if there's e.g. only 1 channel dimension). For comparison, a 128-bit wide vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel dimension into the width dimension of the input/output:

```mlir
%collapsed = tensor.collapse_shape %input [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
%collapsed_0 = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
```

(Note that for this to work, the filter is broadcast rather than re-shaped. Please see the test cases for details.)

The new vectorization strategy is implemented in `depthwiseConvFlatten`, which was modelled on `depthwiseConvGeneric` (i.e. the original vectorization hook). The current vectorization is preserved and kept as the default option. The new vectorization can be selected through e.g. a transform dialect attribute:

```mlir
transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```

A forthcoming patch will implement a strategy to automatically switch between the two implementations, depending on the shape of the input tensors.

Co-authored by: Bradley Smith <[email protected]>

Patch is 38.22 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/71918.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index f1c3d717f1fa951..310efe164f93950 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2034,6 +2034,7 @@ def VectorizeChildrenAndApplyPatternsOp :
let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$vectorize_padding,
UnitAttr:$vectorize_nd_extract,
+ UnitAttr:$flatten_1d_depthwise_conv,
UnitAttr:$disable_multi_reduction_to_contract_patterns,
UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
let results = (outs TransformHandleTypeInterface:$transformed);
@@ -2045,7 +2046,8 @@ def VectorizeChildrenAndApplyPatternsOp :
let builders = [
OpBuilder<(ins "Value":$target,
CArg<"bool", "false">:$vectorizePadding,
- CArg<"bool", "false">:$vectorizeNDExtract)>,
+ CArg<"bool", "false">:$vectorizeNDExtract,
+ CArg<"bool", "false">:$flatten1DDepthwise)>
];
let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6547648f7495c31..a4aee1f45249c2b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes = {},
ArrayRef<bool> inputScalableVecDims = {},
- bool vectorizeNDExtract = false);
+ bool vectorizeNDExtract = false,
+ bool flatten1DDepthwiseConv = false);
/// Emit a suitable vector form for a Copy op with fully static shape.
LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index de4965f937162ea..35e8be7806928e1 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2937,7 +2937,7 @@ LogicalResult TileUsingForallOp::verify() {
void transform::VectorizeChildrenAndApplyPatternsOp::build(
OpBuilder &builder, OperationState &result, Value target,
- bool vectorizePadding, bool vectorizeExtract) {
+ bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
result.addOperands(target);
if (vectorizePadding) {
result.addAttribute(
@@ -2951,6 +2951,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build(
result.name),
builder.getUnitAttr());
}
+ if (flatten1DDepthwiseConv) {
+ result.addAttribute(
+ VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
+ result.name),
+ builder.getUnitAttr());
+ }
result.addTypes(transform::AnyOpType::get(builder.getContext()));
}
@@ -2959,22 +2965,26 @@ namespace {
/// VectorizeChildrenAndApplyPatternsOp::applyToOne.
struct VectorizationPattern : public RewritePattern {
explicit VectorizationPattern(MLIRContext *context,
- bool vectorizeExtract = false)
+ bool vectorizeExtract = false,
+ bool flattenConv = false)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
- vectorizeNDExtract(vectorizeExtract) {}
+ vectorizeNDExtract(vectorizeExtract),
+ flatten1DDepthwiseConv(flattenConv) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
if (!linalgOp)
return rewriter.notifyMatchFailure(op, "expected Linalg Op");
return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
- /*scalableVecDims=*/{}, vectorizeNDExtract);
+ /*scalableVecDims=*/{}, vectorizeNDExtract,
+ flatten1DDepthwiseConv);
}
private:
/// Controls whether to vectorize `tensor.extract` when the input tensor is
/// rank >= 2.
bool vectorizeNDExtract = false;
+ bool flatten1DDepthwiseConv = false;
};
} // namespace
@@ -2991,7 +3001,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
MLIRContext *ctx = getContext();
RewritePatternSet patterns(ctx);
- patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+ patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
+ getFlatten_1dDepthwiseConv());
if (!getDisableTransferPermutationMapLoweringPatterns())
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index b8d82159856825f..f6f74b448edf9a8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -44,8 +44,9 @@ using namespace mlir::linalg;
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
/// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
- LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+ bool flatten1DDepthwiseConv = false);
/// Return the unique instance of OpType in `block` if it is indeed unique.
/// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims,
- bool vectorizeNDExtract) {
+ bool vectorizeNDExtract,
+ bool flatten1DDepthwiseConv) {
LDBG("Attempting to vectorize:\n" << *op << "\n");
LDBG("Input vector sizes: ");
LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
// TODO: isaConvolutionOpInterface that can also infer from generic
// features. Will require stride/dilation attributes inference.
if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
- FailureOr<Operation *> convOr =
- vectorizeConvolution(rewriter, linalgOp);
+ FailureOr<Operation *> convOr = vectorizeConvolution(
+ rewriter, linalgOp, flatten1DDepthwiseConv);
if (succeeded(convOr)) {
llvm::append_range(results, (*convOr)->getResults());
return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
/// kw is always unrolled.
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
/// > 1.
- FailureOr<Operation *> depthwiseConv() {
+ FailureOr<Operation *> depthwiseConvGeneric() {
if (!valid)
return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
.getOperation();
}
+ /// Generate a vector implementation for ("flatten channel dim"):
+ /// ```
+ /// Op def: ( n, w, c, kw)
+ /// Iters: ({Par(), Par(), Par(), Red()})
+ /// Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+ /// ```
+ /// c of the input/output is collapsed with w. kw is always unrolled and
+ /// broadcast to match w.
+ ///
+ /// TODO: Add support for non-unit stride/dilation
+ /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+ /// > 1.
+ FailureOr<Operation *> depthwiseConvFlatten() {
+ if (!valid)
+ return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+ int64_t nSize, iSize, wSize, cSize, kwSize;
+ // kernel{kw, c}
+ bindShapeDims(rhsShapedType, kwSize, cSize);
+ // out{n, w, c}
+ bindShapeDims(resShapedType, nSize, wSize);
+ // in{n, w, c}
+ bindShapeDims(lhsShapedType, nSize, iSize);
+
+ vector::TransferWriteOp write;
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+    if (strideW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit strides are not supported yet");
+    if (dilationW != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Non-unit dilations are not supported yet");
+
+ Type lhsEltType = lhsShapedType.getElementType();
+ Type rhsEltType = rhsShapedType.getElementType();
+ Type resEltType = resShapedType.getElementType();
+ VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+ VectorType lhsType = VectorType::get(
+ {nSize,
+ // iw = (ow * sw + kw * dw - 1) * c
+ // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+ (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+ cSize},
+ lhsEltType);
+
+ VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+
+ Value res, lhs, lhsFlat, resFlat;
+ // Read rhs slice of size {kw, c} @ [0, 0].
+ Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+ ValueRange{zero, zero});
+
+ SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
+
+ // Flatten w and c dimensions
+ auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
+ auto linalgOp = dyn_cast<LinalgOp>(op);
+ lhsFlat =
+ linalgOp.hasTensorSemantics()
+ ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+ loc,
+ RankedTensorType::get(lhsTypeCollapsed.getShape(),
+ lhsEltType),
+ lhsShaped, reassociation)
+ : (Value)rewriter.create<memref::CollapseShapeOp>(
+ loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
+ lhsShaped, reassociation);
+ resFlat =
+ linalgOp.hasTensorSemantics()
+ ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+ loc, RankedTensorType::get(resType.getShape(), resEltType),
+ resShaped, reassociation)
+ : (Value)rewriter.create<memref::CollapseShapeOp>(
+ loc, MemRefType::get(resType.getShape(), resEltType),
+ resShaped, reassociation);
+
+ // Read lhs slice of size {n, (w * wSize + kw * dilationW) * c} @ [0,
+ // 0].
+ lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+ ValueRange{zero, zero});
+ // Read res slice of size {n, w * c} @ [0, 0].
+ res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+ ValueRange{zero, zero});
+
+ //===------------------------------------------------------------------===//
+ // Begin vector-only rewrite part
+ //===------------------------------------------------------------------===//
+ // Unroll along kw and read slices of lhs and rhs.
+ SmallVector<Value> lhsVals, rhsVals, resVals;
+ // Extract lhs slice of size {n, wSizeStep * c}
+ // @ [0, (sw * w + dw * kw) * cSize].
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, lhs,
+ /*offsets=*/
+ ArrayRef<int64_t>{0, (w * wSize + kw * dilationW) * cSize},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1}));
+ }
+ }
+ // Extract rhs slice of size {c} @ [kw].
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ rhsVals.push_back(rewriter.create<vector::ExtractOp>(
+ loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
+ }
+
+ // Extract res slice
+ // Flattened case: {n, wSizeStep * c} @ [0, w].
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, res,
+ /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+ /*sizes=*/ArrayRef<int64_t>{nSize, wSize * cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1}));
+ }
+
+ auto linearIndex = [&](int64_t kw, int64_t w) {
+ return kw * (wSize / wSize) + w;
+ };
+
+ // Compute contraction:
+ // O{n, w * c} += I{n, (sw * w + dw * kw) * c} * F{c}
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ resVals[w] = depthwiseConv1dFlatSliceAsMulAcc(
+ rewriter, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw],
+ resVals[w]);
+ }
+ }
+
+    // It's possible we failed to create the FMA.
+ if (!llvm::all_of(resVals, [](Value v) { return v; })) {
+ // Manually revert (in reverse order) to avoid leaving a bad IR state.
+ for (auto &collection :
+ {resVals, rhsVals, lhsVals, {res, rhs, lhs, zero}})
+ for (Value v : collection)
+ rewriter.eraseOp(v.getDefiningOp());
+ return rewriter.notifyMatchFailure(op, "failed to create FMA");
+ }
+
+ // Write back res slice. This does not depend on kw.
+ // Flattened case: {n, wSizeStep * c} @ [0, w].
+ for (int64_t w = 0; w < wSize; w += wSize) {
+ res = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, resVals[w], res,
+ /*offsets=*/ArrayRef<int64_t>{0, w * cSize},
+ /*strides=*/ArrayRef<int64_t>{1, 1});
+ }
+ //===------------------------------------------------------------------===//
+ // End vector-only rewrite part
+ //===------------------------------------------------------------------===//
+ // Write back res slice of size {n, w * c} @ [0, 0].
+ mlir::vector::TransferWriteOp tWrite =
+ rewriter.create<vector::TransferWriteOp>(loc, res, resFlat,
+ ValueRange{zero, zero});
+
+    // A tensor has to be re-shaped back to its original shape ...
+ if (linalgOp.hasTensorSemantics())
+ // Re-expand shape
+ return rewriter
+ .create<tensor::ExpandShapeOp>(loc, resShapedType, tWrite.getResult(),
+ reassociation)
+ .getOperation();
+    /// ... memrefs don't require reshaping (re-shape is just a different view
+ /// into the same memref)
+ return tWrite.getOperation();
+ }
+
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
Value lhs, Value rhs, Value res) {
@@ -2959,6 +3131,39 @@ struct Conv1DGenerator
return rewriter.create<arith::AddIOp>(loc, mul, res);
}
+ /// Lower lhs{n, w * c} * rhs{c} -> res{n, w * c} to MulAcc
+ Value depthwiseConv1dFlatSliceAsMulAcc(RewriterBase &rewriter, Location loc,
+ Value lhs, Value rhs, Value res) {
+ auto rhsTy = rhs.getType().cast<ShapedType>();
+ auto resTy = res.getType().cast<ShapedType>();
+
+ lhs = promote(rewriter, loc, lhs, resTy);
+
+ auto rhsSize = rhs.getType().cast<VectorType>().getShape()[0];
+ auto resSize = res.getType().cast<VectorType>().getShape()[1];
+
+ SmallVector<int64_t, 16> indicies;
+ for (int i = 0; i < resSize / rhsSize; ++i) {
+ for (int j = 0; j < rhsSize; ++j)
+ indicies.push_back(j);
+ }
+
+ rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indicies);
+
+ rhs = rewriter.create<vector::BroadcastOp>(
+ loc, resTy.clone(rhsTy.getElementType()), rhs);
+ rhs = promote(rewriter, loc, rhs, resTy);
+
+ if (!lhs || !rhs)
+ return nullptr;
+
+ if (resTy.getElementType().isa<FloatType>())
+ return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res);
+
+ auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs);
+ return rewriter.create<arith::AddIOp>(loc, mul, res);
+ }
+
/// Entry point for non-channeled convolution:
/// {{w + kw}, {kw}, {w}}
FailureOr<Operation *> generateNonChanneledConv() {
@@ -3049,7 +3254,7 @@ struct Conv1DGenerator
/// Entry point that transposes into the common form:
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
- FailureOr<Operation *> generateDilatedConv() {
+ FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
AffineExpr n, w, c, kw;
bindDims(ctx, n, w, c, kw);
if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3265,7 @@ struct Conv1DGenerator
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
/*rhsIndex*/ {kw, c},
/*resIndex*/ {n, w, c}}))
- return depthwiseConv();
+ return flatten ? depthwiseConvFlatten() : depthwiseConvGeneric();
return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
}
@@ -3125,8 +3330,9 @@ struct Conv1DGenerator
/// Helper function to vectorize a LinalgOp with convolution semantics.
// TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
- LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+ bool flatten1DDepthwiseConv) {
// The ConvolutionOpInterface gives us guarantees of existence for
// strides/dilations. However, we do not need to rely on those, we can simply
// use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3357,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
res = e.generateNcwPooling();
if (succeeded(res))
return res;
- return e.generateDilatedConv();
+ return e.generateDilatedConv(flatten1DDepthwiseConv);
}
struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
diff --git a/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
new file mode 100644
index 000000000000000..6b0f920bfa42e7f
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorize-convolution-flatten.mlir
@@ -0,0 +1,222 @@
+// RUN: mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s
+
+func.func @flatten_tensor(%input: tensor<1x8x3xi8>, %filter: tensor<1x3xi8>, %output: tensor<1x8x3xi8>) -> (tensor<1x8x3xi8>) {
+ %res = linalg.depthwise_conv_1d_nwc_wc
+ {dilations = dense<1> : vector<1xi64>,
+ strides = dense<1> : vector<1xi64>}
+ ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
+ outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
+ return %res : tensor<1x8x3xi8>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @flatten_tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x3xi8>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3xi8>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x8x3xi8>) -> tensor<1x8x3xi8> ...
[truncated]
For context, here's the previous conversation on this specialisation: We have tried to implement this as a pre-processing phase (as discussed in D149155), but that leads to either:
So, we rejected the idea of any pre-processing and have considered two alternatives:
The benefits of Option 1. are that:
The benefits of Option 2. would be that:
However, I expect that we might see some missing canonicalisations or other issues (TBH, I've only looked at it briefly). I am happy to investigate Option 2 if that's preferable, though I would prefer to do it in follow-up patches and to have this in-tree in the meantime. WDYT? And thank you for all the feedback so far :) 🙏🏻
I made a first pass with some minor restructuring comments.
However, as I dug deeper, I have the impression that you are doing a one-off manual implementation of a new swap(vector.insert/extractStridedSlice, vector.shape_cast)
pattern.
It seems to me you could achieve the same benefits by rewriting:
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
lhsVals[linearIndex(kw, w)],
rhsVals[kw], resVals[w]);
as shape_cast + depthwiseConv1dSliceAsMulAcc + shape_cast
and applying such new swap/propagation patterns.
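Sketched concretely, that bracketing would look roughly as follows (a minimal sketch using the `vector<3x2x4xf32>`/`vector<4xf32>` shapes that appear in the IR dumps later in this thread; SSA names are placeholders, not actual vectoriser output):

```mlir
// Flatten the operands of one multiply-accumulate step ...
%lhs_flat = vector.shape_cast %lhs : vector<3x2x4xf32> to vector<3x8xf32>
%acc_flat = vector.shape_cast %acc : vector<3x2x4xf32> to vector<3x8xf32>
%rhs_bcast = vector.broadcast %rhs : vector<4xf32> to vector<3x2x4xf32>
%rhs_flat = vector.shape_cast %rhs_bcast : vector<3x2x4xf32> to vector<3x8xf32>
// ... perform the flattened mul-acc ...
%fma = vector.fma %lhs_flat, %rhs_flat, %acc_flat : vector<3x8xf32>
// ... and cast back so the surrounding (unflattened) slicing logic is untouched.
%res = vector.shape_cast %fma : vector<3x8xf32> to vector<3x2x4xf32>
```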
I think this would fit within what you refer to as option 2.
In the current form, I am afraid the manual one-off application of that pattern to your current use case is unappealing to me.
Should we schedule a meeting to discuss this? I agree with Nicolas that this might be a one-off case, but I also acknowledge that we have been looking at multiple options with significant drawbacks and we have to find a way to move this forward. A few replies:
Ultimately, this is something we may want to support. Semi-affine maps should be somewhat challenging but not extremely difficult to support (I already played a bit with this) and I think it has been a long-lasting TODO. Strided loads are another TODO, in general, as we are currently generating gather ops for loads with a short static stride.
I think the ultimate goal is to move the whole convolution decomposition step before the vectorizer, but nobody has been brave enough to do it :). I think Nicolas' rewrite suggestion goes along these lines. I haven't looked much into this mechanism, but there is also the possibility of adding this specialization as a hook vectorization pattern. Not sure if we still use that mechanism, but it's still there in the vectorizer so it should work. Perhaps we could extend the API so users can provide external ad-hoc vectorization patterns? (Just a random thought, perhaps this is not even feasible.) Hopefully that helps!
@banach-space I really like how your recent change significantly reduced the complexity compared to the previous approach. Hopefully the shape_cast swaps nicely with the vector transfers next :)
✅ With the latest revision this PR passed the C/C++ code formatter.
@nicolasvasilache Thank you - I was about to write a long reply in which I say "thank you" for your observation (which, indeed, makes things incredibly simpler). But you have already noticed that :)
That would be happening as a post vectorisation canonicalisation, right?
Yes - let me ping you offline :)

As for "strided" loads/stores - super important TODO as well :)

In any case, I have refactored this patch following Nicolas' observation (apologies for over-complicating it so much before). Thank you, that was incredibly helpful 🙏🏻. Now, this change is still a "one off" pattern application within the vectoriser. Should this be done elsewhere? We'd need to match whatever …

Btw, I need to refine the tests - I am aware of that. And to benchmark.

Finally, to better visualise the current changes:

Input:
```mlir
func.func @conv_dill_2(%input: memref<3x5x4xf32>,
                       %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
  linalg.depthwise_conv_1d_nwc_wc
    {dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
    ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
    outs(%output : memref<3x2x4xf32>)
  return
}
```

Vectorisation without shape casting:
```mlir
func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
  %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
  %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
  %7 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
  %8 = vector.fma %3, %7, %2 : vector<3x2x4xf32>
  %9 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
  %10 = vector.fma %4, %9, %8 : vector<3x2x4xf32>
  vector.transfer_write %10, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
  return
}
```

Vectorisation with shape casting:
```mlir
func.func @conv_dill_2(%arg0: memref<3x5x4xf32>, %arg1: memref<2x4xf32>, %arg2: memref<3x2x4xf32>) {
  %c0 = arith.constant 0 : index
  %cst = arith.constant 0.000000e+00 : f32
  %0 = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x5x4xf32>, vector<3x4x4xf32>
  %1 = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x4xf32>, vector<2x4xf32>
  %2 = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<3x2x4xf32>, vector<3x2x4xf32>
  %3 = vector.extract_strided_slice %0 {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %4 = vector.extract_strided_slice %0 {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
  %5 = vector.extract %1[0] : vector<4xf32> from vector<2x4xf32>
  %6 = vector.extract %1[1] : vector<4xf32> from vector<2x4xf32>
  %7 = vector.shape_cast %3 : vector<3x2x4xf32> to vector<3x8xf32>
  %8 = vector.shape_cast %2 : vector<3x2x4xf32> to vector<3x8xf32>
  %9 = vector.broadcast %5 : vector<4xf32> to vector<3x2x4xf32>
  %10 = vector.shape_cast %9 : vector<3x2x4xf32> to vector<3x8xf32>
  %11 = vector.fma %7, %10, %8 : vector<3x8xf32>
  %12 = vector.shape_cast %4 : vector<3x2x4xf32> to vector<3x8xf32>
  %13 = vector.broadcast %6 : vector<4xf32> to vector<3x2x4xf32>
  %14 = vector.shape_cast %13 : vector<3x2x4xf32> to vector<3x8xf32>
  %15 = vector.fma %12, %14, %11 : vector<3x8xf32>
  %16 = vector.shape_cast %15 : vector<3x8xf32> to vector<3x2x4xf32>
  vector.transfer_write %16, %arg2[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<3x2x4xf32>, memref<3x2x4xf32>
  return
}
```
Cool!
Force-pushed from 2cb7081 to 75cc4a6.
Thank you for all the suggestions! @nicolasvasilache I believe that the current implementation addresses all of your comments.
There's hope 🤞🏻, but it's going to be a journey 😅 (I have already landed #73522). If there's no new traffic then I will merge this soon. -Andrzej
Force-pushed from 75cc4a6 to 8c877d6.
Nicolas is OOO, so I would wait a bit, land, and address any further comments in a follow-up commit. That should be OK as we discussed this approach in a meeting.
The current vectorization of 1D depthwise convolutions in Linalg is _sub-optimal_ for tensors with a low number of channel dimensions, e.g.:

```mlir
linalg.depthwise_conv_1d_nwc_wc
  {dilations = dense<1> : vector<1xi64>,
   strides = dense<1> : vector<1xi64>}
  ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
  outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```

That's due to the fact that ultimately (i.e. at LLVM level), vectorization happens along the trailing dimension (i.e. the channel dimension). In this case it leads to vectors with 3 elements (or worse, if there's e.g. only 1 channel dimension). For comparison, a 128-bit wide vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel dimension into the width dimension of the input/output:

```mlir
%collapsed = tensor.collapse_shape %input [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
%collapsed_0 = tensor.collapse_shape %output [[0], [1, 2]] : tensor<1x8x3xi8> into tensor<1x24xi8>
```

(Note that for this to work, the filter is broadcast rather than re-shaped. Please see the test cases for details.)

The new vectorization strategy is implemented in `depthwiseConvFlatten`, which was modelled on `depthwiseConvGeneric` (i.e. the original vectorization hook). The current vectorization is preserved and kept as the default option. The new vectorization can be selected through e.g. a transform dialect attribute:

```mlir
transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```

A forthcoming patch will implement a strategy to automatically switch between the two implementations, depending on the shape of the input tensors.

Co-authored by: Bradley Smith <[email protected]>
…izing
Following on from Nicolas' observation, this commit refactors the implementation to simply replace:

```
resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
                                          lhsVals[linearIndex(kw, w)],
                                          rhsVals[kw], resVals[w]);
```

with shape_cast + depthwiseConv1dSliceAsMulAcc + shape_cast.
…izing Revert renaming + fix formatting
…izing Final tweaks (more comments, revert unrelated change in a test file)
Force-pushed from 8c877d6 to 8e8f56d.
Updates the vectorisation of 1D depthwise convolution when flattening the channel dimension (introduced in #71918). In particular - how the convolution filter is "flattened". ATM, the vectoriser will use `vector.shape_cast`:

```mlir
%b_filter = vector.broadcast %filter : vector<4xf32> to vector<3x2x4xf32>
%sc_filter = vector.shape_cast %b_filter : vector<3x2x4xf32> to vector<3x8xf32>
```

This lowering is not ideal - `vector.shape_cast` can be convenient when it's folded away, but that's not happening in this case. Instead, this patch updates the vectoriser to use `vector.shuffle` (the overall result is identical):

```mlir
%sh_filter = vector.shuffle %filter, %filter [0, 1, 2, 3, 0, 1, 2, 3]
    : vector<4xf32>, vector<4xf32>
%b_filter = vector.broadcast %sh_filter : vector<8xf32> to vector<3x8xf32>
```
The current vectorization of 1D depthwise convolutions in Linalg is sub-optimal for tensors with a low number of channel dimensions, e.g. the `linalg.depthwise_conv_1d_nwc_wc` op on `tensor<1x8x3xi8>` shown at the top of this PR.

That's due to the fact that ultimately (i.e. at LLVM level), vectorization happens along the trailing dimension (i.e. the channel dimension). In this case it leads to vectors with 3 elements (or worse, if there's e.g. only 1 channel dimension). For comparison, a 128-bit wide vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel dimension into the width dimension of the input/filter/output using the `vector.shape_cast` operation. This new vectorization mode is implemented in `depthwiseConv` by inserting `vector.shape_cast` ops before and after `depthwiseConv1dSliceAsMulAcc` is invoked. It can be selected through e.g. a transform dialect attribute, `flatten_1d_depthwise_conv` (see the usage sketch below).

A forthcoming patch will implement a strategy to automatically switch between the two implementations, depending on the shape of the input tensors.

Co-authored by: Bradley Smith [email protected]
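A minimal usage sketch of enabling the new mode via the transform dialect, mirroring the test added in this patch (the surrounding `transform.named_sequence` plumbing from that test is omitted; `%arg0` is assumed to be the payload root handle):

```mlir
// Match the depthwise conv, take its isolated-from-above parent, and
// vectorize its children with channel-dimension flattening enabled.
%0 = transform.structured.match ops{["linalg.depthwise_conv_1d_nwc_wc"]} in %arg0
    : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above}
    : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {flatten_1d_depthwise_conv}
    : (!transform.any_op) -> !transform.any_op
```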