Skip to content

Commit 99ff697

Browse files
[mlir][Vector] Add support for 1D depthwise conv vectorization
At this time the 2 flavors of conv are a little too different to allow significant code sharing, and other flavors will likely come up. So we go the easy route first by duplicating and adapting. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D113758
1 parent 19c1d03 commit 99ff697

File tree

2 files changed

+205
-9
lines changed

2 files changed

+205
-9
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 155 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,16 +1390,25 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
13901390
// Convolution vectorization patterns
13911391
//===----------------------------------------------------------------------===//
13921392
namespace {
1393-
/// Generate a vector implementation for:
1393+
/// Generate a vector implementation for either:
13941394
/// ```
13951395
/// Op def: ( n, w, c, kw, f )
13961396
/// Iters: ({Par(), Par(), Par(), Red(), Red()})
13971397
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
13981398
/// ```
13991399
/// kw is unrolled, w is unrolled iff dilationW > 1.
1400-
struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
1401-
Conv1D_NWC_WCF_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
1402-
int dilationW)
1400+
///
1401+
/// or
1402+
///
1403+
/// ```
1404+
/// Op def: ( n, w, c, kw )
1405+
/// Iters: ({Par(), Par(), Par(), Red()})
1406+
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1407+
/// ```
1408+
/// kw is unrolled, w is unrolled iff dilationW > 1.
1409+
struct Conv1D_NWC_Generator : public StructuredGenerator<LinalgOp> {
1410+
Conv1D_NWC_Generator(OpBuilder &builder, LinalgOp linalgOp, int strideW,
1411+
int dilationW)
14031412
: StructuredGenerator<LinalgOp>(builder, linalgOp), valid(false),
14041413
strideW(strideW), dilationW(dilationW) {
14051414
// Determine whether `linalgOp` can be generated with this generator
@@ -1413,7 +1422,8 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
14131422
resShapedType = resShaped.getType().dyn_cast<ShapedType>();
14141423
if (!lhsShapedType || !rhsShapedType || !resShapedType)
14151424
return;
1416-
if (lhsShapedType.getRank() != 3 || rhsShapedType.getRank() != 3 ||
1425+
if (lhsShapedType.getRank() != 3 ||
1426+
(rhsShapedType.getRank() != 2 && rhsShapedType.getRank() != 3) ||
14171427
resShapedType.getRank() != 3)
14181428
return;
14191429

@@ -1553,12 +1563,130 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
15531563
/*iteratorTypes=*/ArrayRef<StringRef>{par, par, par, red});
15541564
}
15551565

1566+
/// Generate a vector implementation for:
1567+
/// ```
1568+
/// Op def: ( n, w, c, kw)
1569+
/// Iters: ({Par(), Par(), Par(), Red()})
1570+
/// Layout: {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1571+
/// ```
1572+
/// kw is always unrolled.
1573+
/// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is > 1.
1574+
FailureOr<Operation *> dilated_conv() {
1575+
if (!valid)
1576+
return failure();
1577+
1578+
int nSize = lhsShapedType.getShape()[0];
1579+
int wSize = resShapedType.getShape()[1];
1580+
int cSize = lhsShapedType.getShape()[2];
1581+
int kwSize = rhsShapedType.getShape()[0];
1582+
1583+
vector::TransferWriteOp write;
1584+
Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1585+
1586+
// w is unrolled (i.e. wSizeStep == 1) iff strideW > 1.
1587+
// When strideW == 1, we can batch the contiguous loads and avoid unrolling
1588+
int64_t wSizeStep = strideW == 1 ? wSize : 1;
1589+
1590+
Type lhsEltType = lhsShapedType.getElementType();
1591+
Type rhsEltType = rhsShapedType.getElementType();
1592+
Type resEltType = resShapedType.getElementType();
1593+
VectorType lhsType = VectorType::get(
1594+
{nSize, (wSize - 1) * strideW + 1 + (kwSize - 1) * dilationW + 1,
1595+
cSize},
1596+
lhsEltType);
1597+
VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
1598+
VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
1599+
1600+
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, 0].
1601+
Value lhs = builder.create<vector::TransferReadOp>(
1602+
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
1603+
// Read rhs slice of size {kw, c} @ [0, 0].
1604+
Value rhs = builder.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
1605+
ValueRange{zero, zero});
1606+
// Read res slice of size {n, w, c} @ [0, 0, 0].
1607+
Value res = builder.create<vector::TransferReadOp>(
1608+
loc, resType, resShaped, ValueRange{zero, zero, zero});
1609+
1610+
//===------------------------------------------------------------------===//
1611+
// Begin vector-only rewrite part
1612+
//===------------------------------------------------------------------===//
1613+
// Unroll along kw and read slices of lhs and rhs.
1614+
SmallVector<Value> lhsVals, rhsVals, resVals;
1615+
for (int64_t kw = 0; kw < kwSize; ++kw) {
1616+
// Extract rhs slice of size {c} @ [kw].
1617+
rhsVals.push_back(builder.create<vector::ExtractOp>(
1618+
loc, rhs, /*offsets=*/ArrayRef<int64_t>{kw}));
1619+
1620+
for (int64_t w = 0; w < wSize; w += wSizeStep) {
1621+
// Extract lhs slice of size {n, wSizeStep, c}
1622+
// @ [0, sw * w + dw * kw, 0].
1623+
lhsVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
1624+
loc, lhs,
1625+
/*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0},
1626+
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
1627+
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1628+
1629+
// This does not depend on kw.
1630+
if (kw == 0) {
1631+
// Extract res slice: {n, wSizeStep, c} @ [0, w, 0].
1632+
resVals.push_back(builder.create<vector::ExtractStridedSliceOp>(
1633+
loc, res,
1634+
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
1635+
/*sizes=*/ArrayRef<int64_t>{nSize, wSizeStep, cSize},
1636+
/*strides=*/ArrayRef<int64_t>{1, 1, 1}));
1637+
}
1638+
}
1639+
}
1640+
1641+
auto linearIndex = [&](int64_t kw, int64_t w) {
1642+
return kw * (wSize / wSizeStep) + w;
1643+
};
1644+
1645+
// Compute contraction: O{n, w, c} += I{n, sw * w + dw * kw, c} * F{c}
1646+
for (int64_t kw = 0; kw < kwSize; ++kw) {
1647+
for (int64_t w = 0; w < wSize; w += wSizeStep) {
1648+
resVals[w] = dilatedConv1dSliceAsContraction(
1649+
builder, loc, lhsVals[linearIndex(kw, w)], rhsVals[kw], resVals[w]);
1650+
}
1651+
}
1652+
1653+
// Write back res slice: {n, wSizeStep, c} @ [0, w, 0].
1654+
// This does not depend on kw.
1655+
for (int64_t w = 0; w < wSize; w += wSizeStep) {
1656+
res = builder.create<vector::InsertStridedSliceOp>(
1657+
loc, resVals[w], res,
1658+
/*offsets=*/ArrayRef<int64_t>{0, w, 0},
1659+
/*strides=*/ArrayRef<int64_t>{1, 1, 1});
1660+
}
1661+
//===------------------------------------------------------------------===//
1662+
// End vector-only rewrite part
1663+
//===------------------------------------------------------------------===//
1664+
1665+
// Write back res slice of size {n, w, c} @ [0, 0, 0].
1666+
return builder
1667+
.create<vector::TransferWriteOp>(loc, res, resShaped,
1668+
ValueRange{zero, zero, zero})
1669+
.getOperation();
1670+
}
1671+
1672+
// Create a contraction: lhs{n, w, c} * rhs{c} -> res{n, w, c}
1673+
vector::ContractionOp dilatedConv1dSliceAsContraction(OpBuilder &b,
1674+
Location loc, Value lhs,
1675+
Value rhs, Value res) {
1676+
StringRef par = Par().strRef, red = Red().strRef;
1677+
AffineExpr n, w, c;
1678+
bindDims(ctx, n, w, c);
1679+
return builder.create<vector::ContractionOp>(
1680+
loc, lhs, rhs, res,
1681+
/*indexingMaps=*/MapList{{n, w, c}, {c}, {n, w, c}},
1682+
/*iteratorTypes=*/ArrayRef<StringRef>{par, par, red});
1683+
}
1684+
15561685
/// Entry point that transposes into the common form:
15571686
/// {{n, strideW * w + dilationW * kw, c}, {kw, c, f}, {n, w, f}}
15581687
FailureOr<Operation *> generateConv() {
15591688
AffineExpr n, w, f, kw, c;
15601689
bindDims(ctx, n, w, f, kw, c);
1561-
15621690
if (!iters({Par(), Par(), Par(), Red(), Red()}))
15631691
return failure();
15641692

@@ -1570,6 +1698,22 @@ struct Conv1D_NWC_WCF_Generator : public StructuredGenerator<LinalgOp> {
15701698
return failure();
15711699
}
15721700

1701+
/// Entry point that transposes into the common form:
1702+
/// {{n, strideW * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
1703+
FailureOr<Operation *> generateDilatedConv() {
1704+
AffineExpr n, w, c, kw;
1705+
bindDims(ctx, n, w, c, kw);
1706+
if (!iters({Par(), Par(), Par(), Red()}))
1707+
return failure();
1708+
1709+
// No transposition needed.
1710+
if (layout({/*lhsIndex*/ {n, strideW * w + dilationW * kw, c},
1711+
/*rhsIndex*/ {kw, c},
1712+
/*resIndex*/ {n, w, c}}))
1713+
return dilated_conv();
1714+
return failure();
1715+
}
1716+
15731717
private:
15741718
bool valid;
15751719
int strideW, dilationW;
@@ -1588,8 +1732,11 @@ vectorizeConvolution(OpBuilder &b, ConvolutionOpInterface convOp) {
15881732
auto stride = strides ? *strides.getValues<uint64_t>().begin() : 1;
15891733
auto dilation = dilations ? *dilations.getValues<uint64_t>().begin() : 1;
15901734
LinalgOp linalgOp = cast<LinalgOp>(convOp.getOperation());
1591-
Conv1D_NWC_WCF_Generator e(b, linalgOp, stride, dilation);
1592-
return e.generateConv();
1735+
Conv1D_NWC_Generator e(b, linalgOp, stride, dilation);
1736+
auto res = e.generateConv();
1737+
if (succeeded(res))
1738+
return res;
1739+
return e.generateDilatedConv();
15931740
}
15941741

15951742
struct VectorizeConvolution

mlir/test/Dialect/Linalg/vectorize-convolution.mlir

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
180180
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
181181
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
182182
// CHECK-SAME: : vector<4x2x3xf32>, vector<3x8xf32> into vector<4x2x8xf32>
183-
/// w == 1, kw == 1
183+
/// w == 0, kw == 1
184184
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
185185
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
186186
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
@@ -189,3 +189,52 @@ func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<2x3x8xf
189189

190190
// Write the result back in one shot.
191191
// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
192+
193+
// -----
194+
195+
func @depthwise_conv1d_nwc_3x5x4_memref(%input: memref<3x5x4xf32>, %filter: memref<2x4xf32>, %output: memref<3x2x4xf32>) {
196+
linalg.depthwise_conv1D_nw
197+
{dilations = dense<2> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
198+
ins(%input, %filter : memref<3x5x4xf32>, memref<2x4xf32>)
199+
outs(%output : memref<3x2x4xf32>)
200+
return
201+
}
202+
203+
// CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
204+
// CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2) -> (d2)>
205+
206+
// CHECK: func @depthwise_conv1d_nwc_3x5x4_memref
207+
// CHECK-SAME: (%[[INPUT:[0-9a-z]+]]: memref<3x5x4xf32>, %[[FILTER:[0-9a-z]+]]: memref<2x4xf32>, %[[OUTPUT:[0-9a-z]+]]: memref<3x2x4xf32>)
208+
209+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
210+
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
211+
212+
/// Read the whole data in one shot.
213+
// CHECK: %[[V_INPUT_R:.+]] = vector.transfer_read %[[INPUT]][%[[C0]], %[[C0]], %[[C0]]]
214+
// CHECK: %[[V_FILTER_R:.+]] = vector.transfer_read %[[FILTER]][%[[C0]], %[[C0]]]
215+
// CHECK: %[[V_OUTPUT_R:.+]] = vector.transfer_read %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]
216+
217+
/// w == 0, kw == 0
218+
// CHECK: %[[V_FILTER_0:.+]] = vector.extract %[[V_FILTER_R]][0] : vector<2x4xf32>
219+
// CHECK: %[[V_INPUT_0:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
220+
// CHECK-SAME: {offsets = [0, 0, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
221+
/// w == 0, kw == 1
222+
// CHECK: %[[V_FILTER_1:.+]] = vector.extract %[[V_FILTER_R]][1] : vector<2x4xf32>
223+
// CHECK: %[[V_INPUT_1:.+]] = vector.extract_strided_slice %[[V_INPUT_R]]
224+
// CHECK-SAME: {offsets = [0, 2, 0], sizes = [3, 2, 4], strides = [1, 1, 1]} : vector<3x5x4xf32> to vector<3x2x4xf32>
225+
226+
/// w == 0, kw == 0
227+
// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
228+
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
229+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
230+
// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER_0]], %[[V_OUTPUT_R]]
231+
// CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
232+
/// w == 0, kw == 1
233+
// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
234+
// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[INPUT_MAP]]],
235+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
236+
// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER_1]], %[[CONTRACT_0]]
237+
// CHECK-SAME: : vector<3x2x4xf32>, vector<4xf32> into vector<3x2x4xf32>
238+
239+
// Write the result back in one shot.
240+
// CHECK: vector.transfer_write %[[CONTRACT_1]], %[[OUTPUT]][%[[C0]], %[[C0]], %[[C0]]]

0 commit comments

Comments
 (0)