
Commit 03c2f5d

[mlir][linalg][conv] Flatten the channel dimension when vectorizing (#71918)
The current vectorization of 1D depthwise convolutions in Linalg is _sub-optimal_ for tensors with a low number of channel dimensions, e.g.:

```mlir
linalg.depthwise_conv_1d_nwc_wc
  {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>}
  ins(%input, %filter : tensor<1x8x3xi8>, tensor<1x3xi8>)
  outs(%output : tensor<1x8x3xi8>) -> tensor<1x8x3xi8>
```

That's due to the fact that ultimately (i.e. at LLVM level), vectorization happens along the trailing dimension (i.e. the channel dimension). In this case that leads to vectors with 3 elements (or worse, if there's e.g. only 1 channel dimension). For comparison, a 128-bit wide vector register can hold 16 x i8.

Instead, this patch adds an option to flatten/collapse the channel dimension into the width dimension of the input/filter/output using the `vector.shape_cast` operation:

```mlir
%sc_input = vector.shape_cast %input : vector<1x8x3xi8> to vector<1x24xi8>
%sc_output = vector.shape_cast %output : vector<1x8x3xi8> to vector<1x24xi8>
%b_filter = vector.broadcast %filter : vector<3xi8> to vector<1x8x3xi8>
%sc_filter = vector.shape_cast %b_filter : vector<1x8x3xi8> to vector<1x24xi8>
```

This new vectorization mode is implemented in `depthwiseConv` by inserting `vector.shape_cast` Ops before and after `depthwiseConv1dSliceAsMulAcc` is invoked. It can be selected through e.g. a transform dialect attribute:

```mlir
transform.structured.vectorize_children_and_apply_patterns %conv {flatten_1d_depthwise_conv}
```

A forthcoming patch will implement a strategy to automatically switch between the two implementations, depending on the shape of the input tensors.

Co-authored-by: Bradley Smith <[email protected]>
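For reference, a minimal end-to-end transform sequence that requests the flattened lowering could look like the sketch below. The enclosing `transform.sequence` boilerplate and the matched `func.func` are illustrative assumptions, not part of this patch; only the `flatten_1d_depthwise_conv` unit attribute is new here.

```mlir
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
  // Match the function that encloses the depthwise convolution; the
  // vectorize_children_and_apply_patterns op expects an isolated-from-above
  // target and vectorizes the Linalg ops nested inside it.
  %func = transform.structured.match ops{["func.func"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
  // The new unit attribute enables channel-dimension flattening for
  // 1D depthwise convolutions.
  %vectorized = transform.structured.vectorize_children_and_apply_patterns %func
      {flatten_1d_depthwise_conv} : (!transform.any_op) -> !transform.any_op
}
```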
1 parent 98ce2de commit 03c2f5d

File tree

5 files changed: +388 -29 lines changed


mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 1 deletion
```diff
@@ -2038,6 +2038,7 @@ def VectorizeChildrenAndApplyPatternsOp :
   let arguments = (ins TransformHandleTypeInterface:$target,
                    UnitAttr:$vectorize_padding,
                    UnitAttr:$vectorize_nd_extract,
+                   UnitAttr:$flatten_1d_depthwise_conv,
                    UnitAttr:$disable_multi_reduction_to_contract_patterns,
                    UnitAttr:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs TransformHandleTypeInterface:$transformed);
@@ -2049,7 +2050,8 @@ def VectorizeChildrenAndApplyPatternsOp :
   let builders = [
     OpBuilder<(ins "Value":$target,
                CArg<"bool", "false">:$vectorizePadding,
-               CArg<"bool", "false">:$vectorizeNDExtract)>,
+               CArg<"bool", "false">:$vectorizeNDExtract,
+               CArg<"bool", "false">:$flatten1DDepthwise)>
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
```

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -753,7 +753,8 @@ LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/);
 LogicalResult vectorize(RewriterBase &rewriter, Operation *op,
                         ArrayRef<int64_t> inputVectorSizes = {},
                         ArrayRef<bool> inputScalableVecDims = {},
-                        bool vectorizeNDExtract = false);
+                        bool vectorizeNDExtract = false,
+                        bool flatten1DDepthwiseConv = false);

 /// Emit a suitable vector form for a Copy op with fully static shape.
 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp);
```
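For callers outside the transform dialect, the extended entry point can be driven directly from C++. The sketch below is illustrative only: the pattern class name is hypothetical and mirrors the `VectorizationPattern` added in the next file; just the `linalg::vectorize` call comes from this header.

```c++
// Minimal sketch: request the flattened lowering from a custom pattern.
// Assumes an MLIR build where these headers and namespaces are available.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Hypothetical helper pattern, not part of this commit.
struct FlattenDepthwiseConvVectorization
    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // Request the channel-collapsed lowering for 1D depthwise convolutions;
    // all other arguments keep their defaults.
    return linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{},
                             /*inputScalableVecDims=*/{},
                             /*vectorizeNDExtract=*/false,
                             /*flatten1DDepthwiseConv=*/true);
  }
};
} // namespace
```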

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 19 additions & 5 deletions
```diff
@@ -2946,7 +2946,7 @@ LogicalResult TileUsingForallOp::verify() {

 void transform::VectorizeChildrenAndApplyPatternsOp::build(
     OpBuilder &builder, OperationState &result, Value target,
-    bool vectorizePadding, bool vectorizeExtract) {
+    bool vectorizePadding, bool vectorizeExtract, bool flatten1DDepthwiseConv) {
   result.addOperands(target);
   if (vectorizePadding) {
     result.addAttribute(
@@ -2960,6 +2960,12 @@ void transform::VectorizeChildrenAndApplyPatternsOp::build(
             result.name),
         builder.getUnitAttr());
   }
+  if (flatten1DDepthwiseConv) {
+    result.addAttribute(
+        VectorizeChildrenAndApplyPatternsOp::getFlatten_1dDepthwiseConvAttrName(
+            result.name),
+        builder.getUnitAttr());
+  }
   result.addTypes(transform::AnyOpType::get(builder.getContext()));
 }

@@ -2968,22 +2974,29 @@ namespace {
 /// VectorizeChildrenAndApplyPatternsOp::applyToOne.
 struct VectorizationPattern : public RewritePattern {
   explicit VectorizationPattern(MLIRContext *context,
-                                bool vectorizeExtract = false)
+                                bool vectorizeExtract = false,
+                                bool flattenConv = false)
       : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
-        vectorizeNDExtract(vectorizeExtract) {}
+        vectorizeNDExtract(vectorizeExtract),
+        flatten1DDepthwiseConv(flattenConv) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
     return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
-                     /*scalableVecDims=*/{}, vectorizeNDExtract);
+                     /*scalableVecDims=*/{}, vectorizeNDExtract,
+                     flatten1DDepthwiseConv);
   }

 private:
   /// Controls whether to vectorize `tensor.extract` when the input tensor is
   /// rank >= 2.
   bool vectorizeNDExtract = false;
+  /// Controls whether to "flatten" the channel dimension when vectorising 1D
+  /// depthwise convolutions. This should lead to bette vectorization for
+  /// tensors with a low number of channel dimensions.
+  bool flatten1DDepthwiseConv = false;
 };
 } // namespace

@@ -3000,7 +3013,8 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(

   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
+  patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract(),
+                                     getFlatten_1dDepthwiseConv());

   if (!getDisableTransferPermutationMapLoweringPatterns())
     vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
```
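As a usage sketch, the extended builder declared above can also be used to construct the transform op programmatically with flattening enabled. The variables `builder`, `loc`, and `targetHandle` are assumptions standing in for whatever transform-IR construction code surrounds the call.

```c++
// Sketch only: build the transform op with the new flag from C++.
// `builder`, `loc`, and `targetHandle` are assumed to exist in the caller.
auto vectorizeOp =
    builder.create<transform::VectorizeChildrenAndApplyPatternsOp>(
        loc, targetHandle,
        /*vectorizePadding=*/false,
        /*vectorizeNDExtract=*/false,
        /*flatten1DDepthwise=*/true);
```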

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 55 additions & 22 deletions
```diff
@@ -44,8 +44,9 @@ using namespace mlir::linalg;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

 /// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     bool flatten1DDepthwiseConv = false);

 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract) {
+                                      bool vectorizeNDExtract,
+                                      bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
         // TODO: isaConvolutionOpInterface that can also infer from generic
         // features. Will require stride/dilation attributes inference.
         if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-          FailureOr<Operation *> convOr =
-              vectorizeConvolution(rewriter, linalgOp);
+          FailureOr<Operation *> convOr = vectorizeConvolution(
+              rewriter, linalgOp, flatten1DDepthwiseConv);
           if (succeeded(convOr)) {
             llvm::append_range(results, (*convOr)->getResults());
             return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv() {
+  FailureOr<Operation *> depthwiseConv(bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");

@@ -2869,15 +2871,17 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
+    auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
+    auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
+
     // Extract lhs slice of size {n, wSizeStep, c}
     // @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+            inOutSliceSizes, inOutStrides));
       }
     }
     // Extract rhs slice of size {c} @ [kw].
@@ -2889,21 +2893,39 @@ struct Conv1DGenerator
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
+          inOutStrides));
     }

     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };

+    auto inOutFlattenSliceSizes =
+        SmallVector<int64_t>{nSize, wSizeStep * cSize};
+    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+        Value lhsVal = lhsVals[linearIndex(kw, w)];
+        Value resVal = resVals[w];
+        ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
+        if (flatten) {
+          // Flatten the input and filter vectors (collapse the channel
+          // dimension)
+          lhsVal = rewriter.create<vector::ShapeCastOp>(
+              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
+                                                        resVals[w]);
+        }
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
+            rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+        if (flatten) {
+          // Un-flatten the output vector (restore the channel dimension)
+          resVals[w] = rewriter.create<vector::ShapeCastOp>(
+              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+        }
       }
     }

@@ -2936,17 +2958,27 @@ struct Conv1DGenerator
         .getOperation();
   }

-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
+  /// Lower:
+  /// * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
+  /// * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
+  /// to MulAcc.
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
-                                     Value lhs, Value rhs, Value res) {
+                                     Value lhs, Value rhs, Value res,
+                                     ShapedType bcastTy, bool flatten) {
     auto rhsTy = cast<ShapedType>(rhs.getType());
     auto resTy = cast<ShapedType>(res.getType());

     // TODO(suderman): Change this to use a vector.ima intrinsic.
     lhs = promote(rewriter, loc, lhs, resTy);

     rhs = rewriter.create<vector::BroadcastOp>(
-        loc, resTy.clone(rhsTy.getElementType()), rhs);
+        loc, bcastTy.clone(rhsTy.getElementType()), rhs);
+    if (flatten) {
+      // Flatten the channel dimension
+      rhs = rewriter.create<vector::ShapeCastOp>(
+          loc, resTy.clone(rhsTy.getElementType()), rhs);
+    }
+
     rhs = promote(rewriter, loc, rhs, resTy);

     if (!lhs || !rhs)
@@ -3049,7 +3081,7 @@ struct Conv1DGenerator

   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv() {
+  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3092,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv();
+      return depthwiseConv(flatten);

     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3125,8 +3157,9 @@ struct Conv1DGenerator

 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
   // use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3184,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv();
+  return e.generateDilatedConv(flatten1DDepthwiseConv);
 }

 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {
```
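Putting the pieces together: for the `tensor<1x8x3xi8>` example from the commit message, each unrolled (kw, w) step with flattening enabled produces IR along the lines of the sketch below. The SSA names are illustrative; the shape casts and broadcast match the commit message, while the integer multiply-accumulate is what `depthwiseConv1dSliceAsMulAcc` emits for non-float element types.

```mlir
// Sketch of the flattened lowering for one (kw, w) step, assuming the
// vector<1x8x3xi8> shapes from the commit message.
%sc_input  = vector.shape_cast %input  : vector<1x8x3xi8> to vector<1x24xi8>
%sc_output = vector.shape_cast %output : vector<1x8x3xi8> to vector<1x24xi8>
// The filter is broadcast to the unflattened output shape first, then
// collapsed to match the flattened input/output.
%b_filter  = vector.broadcast  %filter   : vector<3xi8> to vector<1x8x3xi8>
%sc_filter = vector.shape_cast %b_filter : vector<1x8x3xi8> to vector<1x24xi8>
// Multiply-accumulate on the flattened 1x24 vectors, then restore the
// channel dimension of the result.
%mul = arith.muli %sc_input, %sc_filter : vector<1x24xi8>
%acc = arith.addi %mul, %sc_output      : vector<1x24xi8>
%res = vector.shape_cast %acc : vector<1x24xi8> to vector<1x8x3xi8>
```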
