Skip to content

Commit 641fe70

Browse files
[mlir][Linalg] Fix and improve vectorization of depthwise convolutions.
When trying to connect the vectorization of depthwise convolutions to e2e execution, a number of problems surfaced. Fix an off-by-one error on the size of the input vector (similarly to what was previously done for regular conv). Rewrite the lowering to vector.fma instead of vector.contract: the KW reduction dimension has already been unrolled and vector.contract requires a reduction dimension to be valid. Differential Revision: https://reviews.llvm.org/D113884
1 parent ee80ffb commit 641fe70

File tree

2 files changed

+23
-35
lines changed

2 files changed

+23
-35
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1477,7 +1477,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
14771477
{nSize,
14781478
// iw = ow * sw + kw * dw - 1
14791479
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1480-
// Perform the proper inclusive -> exclusive -> inclusive
1480+
// Perform the proper inclusive -> exclusive -> inclusive.
14811481
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
14821482
cSize},
14831483
lhsEltType);
@@ -1557,9 +1557,8 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
15571557
}
15581558

15591559
// Create a contraction: lhs{n, w, c} * rhs{c, f} -> res{n, w, f}
1560-
vector::ContractionOp conv1dSliceAsContraction(OpBuilder &b, Location loc,
1561-
Value lhs, Value rhs,
1562-
Value res) {
1560+
Value conv1dSliceAsContraction(OpBuilder &b, Location loc, Value lhs,
1561+
Value rhs, Value res) {
15631562
StringRef par = Par().strRef, red = Red().strRef;
15641563
AffineExpr n, w, f, c;
15651564
bindDims(ctx, n, w, f, c);
@@ -1597,7 +1596,10 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
15971596
Type rhsEltType = rhsShapedType.getElementType();
15981597
Type resEltType = resShapedType.getElementType();
15991598
VectorType lhsType = VectorType::get(
1600-
{nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
1599+
{nSize,
1600+
// iw = ow * sw + kw * dw - 1
1601+
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
1602+
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
16011603
cSize},
16021604
lhsEltType);
16031605
VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
@@ -1651,7 +1653,7 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
16511653
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
16521654
for (int64_t kw = 0; kw < kwSize; ++kw) {
16531655
for (int64_t w = 0; w < wSize; w += wSizeStep) {
1654-
resVals[w] = dilatedConv1dSliceAsContraction(
1656+
resVals[w] = dilatedConv1dSliceAsFma(
16551657
builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
16561658
}
16571659
}
@@ -1675,17 +1677,11 @@ struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
16751677
.getOperation();
16761678
}
16771679

1678-
// Create a contraction: lhs{n, w, c} * rhs{c} -> res{n, w, c}
1679-
vector::ContractionOp dilatedConv1dSliceAsContraction(OpBuilder &b,
1680-
Location loc, Value lhs,
1681-
Value rhs, Value res) {
1682-
StringRef par = Par().strRef, red = Red().strRef;
1683-
AffineExpr n, w, c;
1684-
bindDims(ctx, n, w, c);
1685-
return builder.create<vector::ContractionOp>(
1686-
loc, lhs, rhs, res,
1687-
/*indexingMaps=*/MapList{{n, w, c}, {c}, {n, w, c}},
1688-
/*iteratorTypes=*/ArrayRef<StringRef>{par, par, red});
1680+
/// Lower lhs{n, w, c} * rhs{c} -> res{n, w, c} to fma.
1681+
Value dilatedConv1dSliceAsFma(OpBuilder &b, Location loc, Value lhs,
1682+
Value rhs, Value res) {
1683+
Value bcast = builder.create<vector::BroadcastOp>(loc, res.getType(), rhs);
1684+
return b.create<vector::FMAOp>(loc, lhs, bcast, res);
16891685
}
16901686

16911687
/// Entry point that transposes into the common form:

mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,6 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
200200
return
201201
}
202202

203-
// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
204-
// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)>
205-
206203
// CHECK: func @depthwise_conv1d_nwc_wc_3x5x4_memref
207204
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
208205

@@ -217,24 +214,19 @@ func @depthwise_conv1d_nwc_wc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: m
217214
/// w == 0, kw == 0
218215
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
219216
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
220-
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
217+
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
221218
/// w == 0, kw == 1
222219
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
223220
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
224-
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
221+
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x4x4xf32> to vector<3x2x4xf32>
225222

226-
/// w == 0, kw == 0
227-
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
228-
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
229-
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
230-
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
231-
// CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
232-
/// w == 0, kw == 1
233-
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
234-
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
235-
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
236-
// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
237-
// CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
223+
/// w == 0, kw = 0
224+
// CHECK: %[[B_FILTER_0:.*]] = vector.broadcast %[[V_FILTER_0]] : vector<4xf32> to vector<3x2x4xf32>
225+
// CHECK: %[[FMA_0:.*]] = vector.fma %[[V_INPUT_0]], %[[B_FILTER_0]], %[[V_OUTPUT_R]] : vector<3x2x4xf32>
226+
227+
/// w == 0, kw = 1
228+
// CHECK: %[[B_FILTER_1:.*]] = vector.broadcast %[[V_FILTER_1]] : vector<4xf32> to vector<3x2x4xf32>
229+
// CHECK: %[[FMA_1:.*]] = vector.fma %[[V_INPUT_1]], %[[B_FILTER_1]], %[[FMA_0]] : vector<3x2x4xf32>
238230

239231
// Write the result back in one shot.
240-
// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
232+
// CHECK: vector.transfer_write %[[FMA_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

0 commit comments

Comments
 (0)