Commit da7b6fe

Author: MaheshRavishankar
[mlir][Linalg] Allow tiling of batch dimension for convolution ops with padding.
The existing Linalg tiling implementation already works for tiling the batch dimension of the convolution op, even when the op carries padding. So instead of rejecting padded convolutions outright, only refuse to tile them along non-batch dimensions.

Differential Revision: https://reviews.llvm.org/D76637
1 parent b0cd7b2 commit da7b6fe
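
In Linalg's convention a tile size of zero means "leave that loop untiled", so the change reduces to one rule: a conv with padding may be tiled only along its first loop, the batch dimension. A minimal plain-C++ sketch of that rule (paddedConvTilingSupported is a name invented here, not MLIR API):

    #include <cstdint>
    #include <vector>

    // Invented helper mirroring the guard added in Tiling.cpp below: tiling a
    // padded conv is supported only if every tile size after the first (batch)
    // entry is zero, zero meaning "do not tile this loop".
    bool paddedConvTilingSupported(const std::vector<int64_t> &tileSizes) {
      for (size_t i = 1; i < tileSizes.size(); ++i)
        if (tileSizes[i] != 0)
          return false;
      return true;
    }
    // {2} and {2, 0, 0, 0, 0} pass; {2, 3, 0, 0, 4} fails, matching the two
    // RUN lines in the new test below.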

File tree: 2 files changed, +54 −9 lines


mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (14 additions, 9 deletions)
@@ -336,9 +336,12 @@ Optional<TiledLinalgOp> static tileLinalgOpImpl(OpBuilder &b, LinalgOp op,
          "expected matching number of tile sizes and loops");

   if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
-    // TODO(ntv): add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
+    // For conv op only support tiling along batch dimension (which is the first
+    // loop).
+    if (convOp.padding() &&
+        !llvm::all_of(tileSizes.drop_front(),
+                      [](Value val) { return isZero(val); }))
+      return llvm::None;
   }

   // If permutation is empty, use the identity. Build the permutation map
@@ -420,12 +423,6 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
   if (tileSizes.empty())
     return llvm::None;

-  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
-    // TODO(ntv): add a level of indirection to linalg.generic.
-    if (convOp.padding())
-      llvm_unreachable("Unexpected conv with padding");
-  }
-
   // The following uses the convention that "tiling by zero" skips tiling a
   // particular dimension. This convention is significantly simpler to handle
   // instead of adjusting affine maps to account for missing dimensions.
@@ -436,6 +433,14 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ArrayRef<int64_t> tileSizes,
   if (llvm::all_of(tileSizes, [](int64_t v) { return v == 0; }))
     return llvm::None;

+  if (auto convOp = dyn_cast<linalg::ConvOp>(op.getOperation())) {
+    // For conv op only support tiling along batch dimension (which is the first
+    // loop).
+    if (convOp.padding() && !llvm::all_of(tileSizes.drop_front(),
+                                          [](int64_t val) { return val == 0; }))
+      return llvm::None;
+  }
+
   // Create a builder for tile size constants.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
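
Both overloads of tileLinalgOpImpl now carry the same guard; the only difference is that the first sees tile sizes as SSA values (hence the isZero(val) check), while the second sees them as static int64_t constants. In the second overload the guard also moves below the early exits for an empty or all-zero tile-size list, so a padded conv with no tiling requested simply falls through untouched.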
New test file (40 additions, 0 deletions)

@@ -0,0 +1,40 @@
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,0,0,4" | FileCheck %s -check-prefix=TILE-23004
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2" | FileCheck %s -check-prefix=TILE-20000
+
+// TILE-23004-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+// TILE-20000-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
+// TILE-20000-DAG: #[[minmap:.*]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+// TILE-20000-DAG: #[[subviewstride:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
+
+func @conv_padding(%arg0: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg1: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, %arg2: memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>) {
+  linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], padding = dense<[[1, 1], [0, 1]]> : tensor<2x2xi64>, strides = [30, 40]} : memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>, memref<?x?x?x?xf32, offset: ?, strides: [?, ?, ?, 1]>
+  return
+}
+// TILE-23004-LABEL: func @conv_padding(
+// TILE-23004-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[strided4D]]>)
+// TILE-23004: linalg.conv(%[[ARG0]], %[[ARG1]], %[[ARG2]])
+
+// TILE-20000-LABEL: func @conv_padding(
+// TILE-20000-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-20000-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-20000-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref<?x?x?x?xf32, #[[strided4D]]>)
+// TILE-20000-DAG: %[[C0:.*]] = constant 0 : index
+// TILE-20000-DAG: %[[C1:.*]] = constant 1 : index
+// TILE-20000-DAG: %[[C2:.*]] = constant 2 : index
+// TILE-20000: %[[B:.*]] = dim %[[ARG1]], 0
+// TILE-20000: loop.for %[[ivI:.*]] = %[[C0]] to %[[B]] step %[[C2]] {
+// TILE-20000:   %[[DIM10:.*]] = dim %[[ARG1]], 0
+// TILE-20000:   %[[EXTENT:.*]] = affine.min #[[minmap]](%[[C2]], %[[DIM10]], %[[ivI]])
+// TILE-20000:   %[[DIM11:.*]] = dim %[[ARG1]], 1
+// TILE-20000:   %[[DIM12:.*]] = dim %[[ARG1]], 2
+// TILE-20000:   %[[DIM13:.*]] = dim %[[ARG1]], 3
+// TILE-20000:   %[[SUBVIEW1:.*]] = subview %[[ARG1]][%[[ivI]], %[[C0]], %[[C0]], %[[C0]]] [%[[EXTENT]], %[[DIM11]], %[[DIM12]], %[[DIM13]]]
+// TILE-20000:   %[[DIM20:.*]] = dim %[[ARG2]], 0
+// TILE-20000:   %[[EXTENT:.*]] = affine.min #[[minmap]](%[[C2]], %[[DIM20]], %[[ivI]])
+// TILE-20000:   %[[DIM21:.*]] = dim %[[ARG2]], 1
+// TILE-20000:   %[[DIM22:.*]] = dim %[[ARG2]], 2
+// TILE-20000:   %[[DIM23:.*]] = dim %[[ARG2]], 3
+// TILE-20000:   %[[SUBVIEW2:.*]] = subview %[[ARG2]][%[[ivI]], %[[C0]], %[[C0]], %[[C0]]] [%[[EXTENT]], %[[DIM21]], %[[DIM22]], %[[DIM23]]]
+// TILE-20000:   linalg.conv(%[[ARG0]], %[[SUBVIEW1]], %[[SUBVIEW2]])
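
The TILE-23004 run asks for tile sizes 2,3,0,0,4; because loops other than the batch loop would be tiled, the padded conv is left untouched, and the checks only verify that linalg.conv still consumes the original arguments. The TILE-20000 run tiles the batch loop alone. Roughly, the loop structure it checks for corresponds to the following sketch (plain C++ rather than IR; conv2dOnSlice is a placeholder, not a real API):

    #include <algorithm>
    #include <cstdint>

    // Sketch of the batch-tiled structure TILE-20000 checks: a loop.for over
    // the batch dimension in steps of 2, with affine.min clamping the last
    // tile so it does not run past the batch extent.
    void tiledConvPadding(int64_t batch /* = dim %arg1, 0 */) {
      const int64_t tileSize = 2; // from linalg-tile-sizes=2
      for (int64_t iv = 0; iv < batch; iv += tileSize) {
        // %[[EXTENT]] = affine.min #minmap(%c2, %batch, %iv) == min(2, batch - iv)
        int64_t extent = std::min(tileSize, batch - iv);
        // Take subviews of the input and output along the batch dimension,
        // then run the padded conv on just that slice:
        // conv2dOnSlice(inputSubview(iv, extent), outputSubview(iv, extent));
        (void)extent;
      }
    }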
