
Commit 8e8f56d

fixup! [mlir][linalg][conv] Flatten the channel dimension when vectorizing
Final tweaks (more comments, revert unrelated change in a test file)
1 parent: 69fa20e
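
For context, the computation this commit targets is a 1-D depthwise convolution, O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c} (see the contraction comment in the Vectorization.cpp hunk below). Every channel c is computed independently, which is what makes it legal to collapse the w and c dimensions into one. A minimal standalone sketch of that reference computation in plain C++ with made-up sizes (illustrative only, not the MLIR implementation):

// Hedged reference sketch (plain C++, made-up sizes; not the MLIR code):
// 1-D depthwise convolution O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}.
// Each channel c is independent, which is what makes collapsing the w and c
// dimensions into one ("flattening") legal.
#include <cstdio>
#include <vector>

int main() {
  const int N = 1, W = 4, C = 2, KW = 2, SW = 1, DW = 1;
  const int WI = SW * (W - 1) + DW * (KW - 1) + 1; // required input width
  std::vector<float> I(N * WI * C, 1.0f), F(C, 2.0f), O(N * W * C, 0.0f);

  for (int n = 0; n < N; ++n)
    for (int kw = 0; kw < KW; ++kw)
      for (int w = 0; w < W; ++w)
        for (int c = 0; c < C; ++c)
          O[(n * W + w) * C + c] +=
              I[(n * WI + (SW * w + DW * kw)) * C + c] * F[c];

  std::printf("O[0][0][0] = %g\n", O[0]); // 1*2 (kw=0) + 1*2 (kw=1) = 4
  return 0;
}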

File tree

4 files changed: +267, -173 lines


mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 0 deletions
@@ -2990,6 +2990,9 @@ struct VectorizationPattern : public RewritePattern {
   /// Controls whether to vectorize `tensor.extract` when the input tensor is
   /// rank >= 2.
   bool vectorizeNDExtract = false;
+  /// Controls whether to "flatten" the channel dimension when vectorising 1D
+  /// depthwise convolutions. This should lead to better vectorization for
+  /// tensors with a low number of channels.
   bool flatten1DDepthwiseConv = false;
 };
 } // namespace
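
The new flag is opt-in. A minimal standalone mirror of the options shape shown above (plain C++, illustrative only; these are not the real MLIR structs):

// Minimal mirror of the options shape in the hunk above (illustrative only,
// not the actual MLIR types): the new flag defaults to false, so existing
// vectorization behavior is unchanged unless a caller opts in.
#include <cstdio>

struct VectorizationOptions { // hypothetical stand-in for the pattern's fields
  bool vectorizeNDExtract = false;
  bool flatten1DDepthwiseConv = false; // new in this commit
};

int main() {
  VectorizationOptions defaults; // flattening stays off by default
  VectorizationOptions flattened;
  flattened.flatten1DDepthwiseConv = true; // explicit opt-in
  std::printf("default: %d, opted-in: %d\n",
              (int)defaults.flatten1DDepthwiseConv,
              (int)flattened.flatten1DDepthwiseConv);
  return 0;
}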

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 15 additions & 9 deletions
@@ -2881,8 +2881,7 @@ struct Conv1DGenerator
         lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
             loc, lhs,
             /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
-            inOutSliceSizes,
-            inOutStrides));
+            inOutSliceSizes, inOutStrides));
       }
     }
     // Extract rhs slice of size {c} @ [kw].
@@ -2894,35 +2893,39 @@ struct Conv1DGenerator
     for (int64_t w = 0; w < wSize; w += wSizeStep) {
       resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>(
           loc, res,
-          /*offsets=*/ArrayRef<int64_t>{0, w, 0},
-          inOutSliceSizes,
+          /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes,
           inOutStrides));
     }
 
     auto linearIndex = [&](int64_t kw, int64_t w) {
       return kw * (wSize / wSizeStep) + w;
     };
 
-    auto inOutFlattenSliceSizes = SmallVector<int64_t>{nSize, wSizeStep * cSize};
+    auto inOutFlattenSliceSizes =
+        SmallVector<int64_t>{nSize, wSizeStep * cSize};
     auto lhsCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
-    auto resCastType = VectorType::get(inOutFlattenSliceSizes, lhsEltType);
+    auto resCastType = VectorType::get(inOutFlattenSliceSizes, resEltType);
     // Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
     for (int64_t kw = 0; kw < kwSize; ++kw) {
       for (int64_t w = 0; w < wSize; w += wSizeStep) {
         Value lhsVal = lhsVals[linearIndex(kw, w)];
         Value resVal = resVals[w];
         ShapedType filterBCastTy = cast<ShapedType>(resVal.getType());
         if (flatten) {
+          // Flatten the input and output vectors (collapse the channel
+          // dimension)
           lhsVal = rewriter.create<vector::ShapeCastOp>(
               loc, lhsCastType, lhsVals[linearIndex(kw, w)]);
           resVal = rewriter.create<vector::ShapeCastOp>(loc, resCastType,
                                                         resVals[w]);
         }
         resVals[w] = depthwiseConv1dSliceAsMulAcc(
             rewriter, loc, lhsVal, rhsVals[kw], resVal, filterBCastTy, flatten);
-        if (flatten)
+        if (flatten) {
+          // Un-flatten the output vector (restore the channel dimension)
           resVals[w] = rewriter.create<vector::ShapeCastOp>(
               loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]);
+        }
       }
     }
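
The shape casts above are cheap because a vector shape cast only reinterprets a contiguous value. A hedged standalone sketch (plain C++, made-up sizes and values; not the MLIR code) of the flatten / mul-acc / un-flatten round trip, modeling the cast as viewing the same [n][w][c] buffer as [n][w*c]:

// Hedged sketch of the flatten / mul-acc / un-flatten round trip (plain C++,
// made-up sizes and values). The "shape cast" is modeled by indexing the same
// contiguous [n][w][c] buffer as [n][w*c], which is why the cast itself moves
// no data.
#include <array>
#include <cstdio>

int main() {
  constexpr int n = 1, w = 2, c = 3;
  std::array<float, n * w * c> in{1, 2, 3, 4, 5, 6}; // input slice
  std::array<float, c> filter{10, 20, 30};           // one filter tap F{c}
  std::array<float, n * w * c> out{};                // accumulator slice

  // Flattened view: one mul-acc pass over [n][w*c]; the filter is broadcast
  // along w, mirroring the vector::BroadcastOp (+ ShapeCastOp) on rhs.
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < w * c; ++j)
      out[i * (w * c) + j] += in[i * (w * c) + j] * filter[j % c];

  // "Un-flattening" is again just reinterpretation: out[(i*w + x)*c + y]
  // already holds O{i, x, y}.
  for (float v : out)
    std::printf("%g ", v); // 10 40 90 40 100 180
  std::printf("\n");
  return 0;
}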

@@ -2970,8 +2973,11 @@ struct Conv1DGenerator
 
     rhs = rewriter.create<vector::BroadcastOp>(
         loc, bcastTy.clone(rhsTy.getElementType()), rhs);
-    if (flatten)
-      rhs = rewriter.create<vector::ShapeCastOp>(loc, resTy, rhs);
+    if (flatten) {
+      // Flatten the channel dimension
+      rhs = rewriter.create<vector::ShapeCastOp>(
+          loc, resTy.clone(rhsTy.getElementType()), rhs);
+    }
 
     rhs = promote(rewriter, loc, rhs, resTy);
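
This last hunk also fixes the target type of the rhs shape cast: it now clones only the shape of resTy while keeping rhs's element type, leaving any element-type conversion to the promote() call that follows, since a shape cast is layout-only. A small sketch of that split with hypothetical toy types (not the MLIR API):

// Sketch of the element-type fix above (hypothetical toy types, not the MLIR
// API): a shape cast may change shape but not element type; promotion then
// changes the element type separately.
#include <cstdio>
#include <string>
#include <vector>

struct VecTy { // toy stand-in for a vector type
  std::vector<int> shape;
  std::string elt;
  VecTy cloneWithElt(const std::string &e) const { return {shape, e}; }
};

int main() {
  VecTy rhsTy{{3}, "f16"};    // filter slice, e.g. vector<3xf16>
  VecTy resTy{{1, 6}, "f32"}; // flattened accumulator, vector<1x6xf32>
  // Shape cast: resTy's shape, rhs's element type (what the fix does).
  VecTy casted = resTy.cloneWithElt(rhsTy.elt); // vector<1x6xf16>
  // Promotion afterwards converts the element type to match resTy.
  VecTy promoted = casted.cloneWithElt(resTy.elt); // vector<1x6xf32>
  std::printf("cast elt: %s, promoted elt: %s\n", casted.elt.c_str(),
              promoted.elt.c_str());
  return 0;
}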
