@@ -44,8 +44,9 @@ using namespace mlir::linalg;
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 /// Try to vectorize `convOp` as a convolution.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp convOp);
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
+                     bool flatten1DDepthwiseConv = false);
 
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
@@ -1664,7 +1665,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<int64_t> inputVectorSizes,
                                       ArrayRef<bool> inputScalableVecDims,
-                                      bool vectorizeNDExtract) {
+                                      bool vectorizeNDExtract,
+                                      bool flatten1DDepthwiseConv) {
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
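
For context, the new flag simply appends to the existing parameter list. A minimal sketch of a call site, assuming a `RewriterBase &rewriter` and a candidate `Operation *op` are already in scope; the empty size/dim lists and the chosen flag values are illustrative, not part of this patch:

    // Hypothetical caller: vectorize `op`, asking for 1-D depthwise
    // convolutions to be flattened. All values here are assumptions.
    if (failed(mlir::linalg::vectorize(rewriter, op,
                                       /*inputVectorSizes=*/{},
                                       /*inputScalableVecDims=*/{},
                                       /*vectorizeNDExtract=*/false,
                                       /*flatten1DDepthwiseConv=*/true)))
      return failure();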
@@ -1696,8 +1698,8 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
         // TODO: isaConvolutionOpInterface that can also infer from generic
         // features. Will require stride/dilation attributes inference.
         if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
-          FailureOr<Operation *> convOr =
-              vectorizeConvolution(rewriter, linalgOp);
+          FailureOr<Operation *> convOr = vectorizeConvolution(
+              rewriter, linalgOp, flatten1DDepthwiseConv);
           if (succeeded(convOr)) {
             llvm::append_range(results, (*convOr)->getResults());
             return success();
@@ -2822,7 +2824,7 @@ struct Conv1DGenerator
   /// kw is always unrolled.
   /// TODO: w (resp. kw) is unrolled when the strideW (resp. dilationW) is
   /// > 1.
-  FailureOr<Operation *> depthwiseConv() {
+  FailureOr<Operation *> depthwiseConv(bool flatten) {
     if (!valid)
       return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
 
@@ -2869,15 +2871,17 @@ struct Conv1DGenerator
     //===------------------------------------------------------------------===//
     // Unroll along kw and read slices of lhs and rhs.
     SmallVector<Value> lhsVals, rhsVals, resVals;
+    auto inOutSliceSizes = SmallVector<int64_t>{nSize, wSizeStep, cSize};
+    auto inOutStrides = SmallVector<int64_t>{1, 1, 1};
+
     // Extract lhs slice of size {n, wSizeStep, c}
     //   @ [0, sw * w + dw * kw, 0].
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-            /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+            inOutSliceSizes, inOutStrides));
       }
     }
     // Extract rhs slice of size {c} @ [kw].
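
As a worked instance of the lhs slice read above, take illustrative sizes nSize = 1, wSizeStep = 4, cSize = 4, strideW = 1, dilationW = 2 (none of these come from the patch). For kw = 1 and w = 0, the offset expression evaluates to 0 * 1 + 1 * 2 = 2, so the loop body is equivalent to:

    // Evaluated form of the extraction for kw = 1, w = 0; reads the
    // {1, 4, 4} lhs tile starting at [0, 2, 0]. Sizes are assumptions.
    lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
        loc, lhs,
        /*offsets=*/ArrayRef<int64_t>{0, 2, 0},
        inOutSliceSizes, inOutStrides));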
@@ -2889,21 +2893,39 @@ struct Conv1DGenerator
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          /*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
-          /*strides=*/ArrayRef<int64_t>{1, 1, 1}));
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
+          inOutStrides));
     }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };
 
+    auto inOutFlattenSliceSizes =
+        SmallVector<int64_t>{nSize, wSizeStep * cSize};
+    auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
-        resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc,
-                                                  lhsVals[linearIndex(kw, w)],
-                                                  rhsVals[kw], resVals[w]);
+        Value lhsVal = lhsVals[linearIndex(kw, w)];
+        Value resVal = resVals[w];
+        ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
+        if (flatten) {
+          // Flatten the input and filter vectors (collapse the channel
+          // dimension)
+          lhsVal = rewriter.create<vector::ShapeCastOp>(
+              loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
+          resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
+                                                        resVals[w]);
+        }
+        resVals[w] = depthwiseConv1dSliceAsMulAcc(
+            rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
+        if (flatten) {
+          // Un-flatten the output vector (restore the channel dimension)
+          resVals[w] = rewriter.create<vector::ShapeCastOp>(
+              loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+        }
       }
     }
 
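
Stripped of the unrolling loops, the flatten = true path is a shape-cast sandwich around the multiply-accumulate. A condensed sketch of one iteration; `computeMulAcc` is a hypothetical stand-in for the depthwiseConv1dSliceAsMulAcc call, everything else uses names from the patch:

    // Collapse {n, w, c} -> {n, w * c}: fold the channel dimension into
    // the width dimension before the multiply-accumulate.
    auto flatTy = VectorType::get({nSize, wSizeStep * cSize}, resEltType);
    Value flatLhs = rewriter.create<vector::ShapeCastOp>(loc, flatTy, lhsVal);
    Value flatRes = rewriter.create<vector::ShapeCastOp>(loc, flatTy, resVal);
    // Hypothetical helper standing in for depthwiseConv1dSliceAsMulAcc.
    Value flatAcc = computeMulAcc(flatLhs, rhsVals[kw], flatRes);
    // Un-flatten: restore the {n, w, c} shape of the result slice.
    resVals[w] = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({nSize, wSizeStep, cSize}, resEltType), flatAcc);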
@@ -2936,17 +2958,27 @@ struct Conv1DGenerator
         .getOperation();
   }
 
-  /// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to MulAcc
+  /// Lower:
+  ///   * lhs{n, w, c} * rhs{c} -> res{n, w, c} (flatten = false)
+  ///   * lhs{n, w * c} * rhs{c} -> res{n, w * c} (flatten = true)
+  /// to MulAcc.
   Value depthwiseConv1dSliceAsMulAcc(RewriterBase &rewriter, Location loc,
-                                     Value lhs, Value rhs, Value res) {
+                                     Value lhs, Value rhs, Value res,
+                                     ShapedType bcastTy, bool flatten) {
     auto rhsTy = cast<ShapedType>(rhs.getType());
     auto resTy = cast<ShapedType>(res.getType());
 
     // TODO(suderman): Change this to use a vector.ima intrinsic.
     lhs = promote(rewriter, loc, lhs, resTy);
 
     rhs = rewriter.create<vector::BroadcastOp>(
-        loc, resTy.clone(rhsTy.getElementType()), rhs);
+        loc, bcastTy.clone(rhsTy.getElementType()), rhs);
+    if (flatten) {
+      // Flatten the channel dimension
+      rhs = rewriter.create<vector::ShapeCastOp>(
+          loc, resTy.clone(rhsTy.getElementType()), rhs);
+    }
+
     rhs = promote(rewriter, loc, rhs, resTy);
 
     if (!lhs || !rhs)
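
Note the ordering on the filter side: `bcastTy` is the un-flattened result type, so the 1-D filter is first broadcast out to three dimensions and only then collapsed to match the flattened lhs/res. With illustrative shapes (n = 1, w = 4, c = 4; not taken from the patch), the two steps amount to:

    // rhs : vector<4xf32>, one weight per channel.
    // Step 1: broadcast to the un-flattened {1, 4, 4} shape.
    Value bcast = rewriter.create<vector::BroadcastOp>(
        loc, VectorType::get({1, 4, 4}, rhsTy.getElementType()), rhs);
    // Step 2: collapse to the flattened {1, 16} shape used by lhs/res.
    rhs = rewriter.create<vector::ShapeCastOp>(
        loc, VectorType::get({1, 16}, rhsTy.getElementType()), bcast);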
@@ -3049,7 +3081,7 @@ struct Conv1DGenerator
 
   /// Entry point that transposes into the common form:
   ///   {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
-  FailureOr<Operation *> generateDilatedConv() {
+  FailureOr<Operation *> generateDilatedConv(bool flatten = false) {
     AffineExpr n, w, c, kw;
     bindDims(ctx, n, w, c, kw);
     if (!iters({Par(), Par(), Par(), Red()}))
@@ -3060,7 +3092,7 @@ struct Conv1DGenerator
     if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
                 /*rhsIndex*/ {kw, c},
                 /*resIndex*/ {n, w, c}}))
-      return depthwiseConv();
+      return depthwiseConv(flatten);
 
     return rewriter.notifyMatchFailure(op, "not a depthwise::Nwc layout");
   }
@@ -3125,8 +3157,9 @@ struct Conv1DGenerator
 
 /// Helper function to vectorize a LinalgOp with convolution semantics.
 // TODO: extend the generic vectorization to support windows and drop this.
-static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
-                                                   LinalgOp op) {
+static FailureOr<Operation *>
+vectorizeConvolution(RewriterBase &rewriter, LinalgOp op,
+                     bool flatten1DDepthwiseConv) {
   // The ConvolutionOpInterface gives us guarantees of existence for
   // strides/dilations. However, we do not need to rely on those, we can simply
   // use them if present, otherwise use the default and let the generic conv.
@@ -3151,7 +3184,7 @@ static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter,
   res = e.generateNcwPooling();
   if (succeeded(res))
     return res;
-  return e.generateDilatedConv();
+  return e.generateDilatedConv(flatten1DDepthwiseConv);
 }
 
 struct VectorizeConvolution : public OpInterfaceRewritePattern<LinalgOp> {