@@ -2881,8 +2881,7 @@ struct Conv1DGenerator
2881
2881
lhsVals.push_back (rewriter.create <vector::ExtractStridedSliceOp>(
2882
2882
loc, lhs,
2883
2883
/* offsets=*/ ArrayRef<int64_t >{0 , w * strideW + kw * dilationW, 0 },
2884
- inOutSliceSizes,
2885
- inOutStrides));
2884
+ inOutSliceSizes, inOutStrides));
2886
2885
}
2887
2886
}
2888
2887
// Extract rhs slice of size {c} @ [kw].
@@ -2894,35 +2893,39 @@ struct Conv1DGenerator
2894
2893
for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
2895
2894
resVals.push_back (rewriter.create <vector::ExtractStridedSliceOp>(
2896
2895
loc, res,
2897
- /* offsets=*/ ArrayRef<int64_t >{0 , w, 0 },
2898
- inOutSliceSizes,
2896
+ /* offsets=*/ ArrayRef<int64_t >{0 , w, 0 }, inOutSliceSizes,
2899
2897
inOutStrides));
2900
2898
}
2901
2899
2902
2900
// Maps a (kw, w) pair to a flat index computed as
// kw * (wSize / wSizeStep) + w. The contraction loop uses it to look up
// entries of lhsVals.
// NOTE(review): `w` is added unscaled while the visible loops advance w by
// wSizeStep; presumably wSizeStep is either 1 or wSize so the index stays
// in range -- confirm against where wSizeStep is defined.
auto linearIndex = [&](int64_t kw, int64_t w) {
2903
2901
return kw * (wSize / wSizeStep) + w;
2904
2902
};
2905
2903
2906
- auto inOutFlattenSliceSizes = SmallVector<int64_t >{nSize, wSizeStep * cSize};
2904
+ auto inOutFlattenSliceSizes =
2905
+ SmallVector<int64_t >{nSize, wSizeStep * cSize};
2907
2906
auto lhsCastType = VectorType::get (inOutFlattenSliceSizes, lhsEltType);
2908
- auto resCastType = VectorType::get (inOutFlattenSliceSizes, lhsEltType );
2907
+ auto resCastType = VectorType::get (inOutFlattenSliceSizes, resEltType );
2909
2908
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
2910
2909
for (int64_t kw = 0 ; kw < kwSize; ++kw) {
2911
2910
for (int64_t w = 0 ; w < wSize; w += wSizeStep) {
2912
2911
Value lhsVal = lhsVals[linearIndex (kw, w)];
2913
2912
Value resVal = resVals[w];
2914
2913
ShapedType filterBCastTy = cast<ShapedType>(resVal.getType ());
2915
2914
if (flatten) {
2915
+ // Flatten the input and filter vectors (collapse the channel
2916
+ // dimension)
2916
2917
lhsVal = rewriter.create <vector::ShapeCastOp>(
2917
2918
loc, lhsCastType, lhsVals[linearIndex (kw, w)]);
2918
2919
resVal = rewriter.create <vector::ShapeCastOp>(loc, resCastType,
2919
2920
resVals[w]);
2920
2921
}
2921
2922
resVals[w] = depthwiseConv1dSliceAsMulAcc (
2922
2923
rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
2923
- if (flatten)
2924
+ if (flatten) {
2925
+ // Un-flatten the output vector (restore the channel dimension)
2924
2926
resVals[w] = rewriter.create <vector::ShapeCastOp>(
2925
2927
loc, VectorType::get (inOutSliceSizes, resEltType), resVals[w]);
2928
+ }
2926
2929
}
2927
2930
}
2928
2931
@@ -2970,8 +2973,11 @@ struct Conv1DGenerator
2970
2973
2971
2974
rhs = rewriter.create <vector::BroadcastOp>(
2972
2975
loc, bcastTy.clone (rhsTy.getElementType ()), rhs);
2973
- if (flatten)
2974
- rhs = rewriter.create <vector::ShapeCastOp>(loc, resTy, rhs);
2976
+ if (flatten) {
2977
+ // Flatten the channel dimension
2978
+ rhs = rewriter.create <vector::ShapeCastOp>(
2979
+ loc, resTy.clone (rhsTy.getElementType ()), rhs);
2980
+ }
2975
2981
2976
2982
rhs = promote (rewriter, loc, rhs, resTy);
2977
2983
0 commit comments