Skip to content

Commit 83c56aa

Browse files
committed
[mlir][linalg] Add depthwise_conv_2d_input_nhwc_filter_hwcf to Linalg TC ops.
Different from the definition in Tensorflow and TOSA, the output is [N,H,W,C,M]. This can make transforms easier in LinAlg because the indexing maps are plain. E.g., to determine if the fill op has dependency between the depthwise conv op, the current pipeline only recognizes the dep if they are all projected affine map. Reviewed By: asaadaldien Differential Revision: https://reviews.llvm.org/D97798
1 parent 8c3a70a commit 83c56aa

File tree

3 files changed

+99
-2
lines changed

3 files changed

+99
-2
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,45 @@ def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N,
151151
std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw)));
152152
}
153153

154+
ods_def<DepthwiseConvInputNHWCFilterHWCFOp>:
155+
def depthwise_conv_2d_input_nhwc_filter_hwcf
156+
(I: f32(N, IH, IW, CI), K: f32(KH, KW, CI, CO))
157+
-> (O: f32(N, OH, OW, CI, CO))
158+
attr(strides: 2xi64)
159+
"""A general depth-wise 2-D convolution operation.
160+
161+
This operation performs depth-wise 2-D convolution over an input `I` and filter
162+
`F` and generates output `O` using the following computation:
163+
164+
```
165+
O(n, oh, ow, ci, co) = std_addf<kh, kw>(
166+
O(n, oh, ow, ci, co),
167+
std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
168+
K(kh, kw, ci, co)));
169+
```
170+
171+
where
172+
173+
* `I` is a 4-D tensor with shape `(N, IH, IW, CI)`.
174+
* `F` is a 4-D tensor with shape `(KH, KW, CI, CO)`.
175+
* `O` is a 5-D tensor with shape `(N, OH, OW, CI, CO)`.
176+
* `strides` is a 2-element vector attribute for window strides along the
177+
height/width dimension.
178+
179+
The indexing maps for these three tensors contain 7 dimensions, following the
180+
order of (`N`, `OH`, `OW`, `CI`, `CO`, `KH`, `KW`).
181+
182+
Note: this op only supports any channel multiplier, which is `CO`. To map back
183+
to 4D result as DepthwiseConvInputNHWCFilterHWCOp, you will have to create a
184+
Linalg reshape op which collapses `CI` and `CO` into one dimension.
185+
"""
186+
{
187+
O(n, oh, ow, ci, co) = std_addf<kh, kw>(
188+
O(n, oh, ow, ci, co),
189+
std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, ci),
190+
K(kh, kw, ci, co)));
191+
}
192+
154193
ods_def<DepthwiseConvInputNHWCFilterHWCOp>:
155194
def depthwise_conv_2d_input_nhwc_filter_hwc
156195
(I: f32(N, IH, IW, C), K: f32(KH, KW, C))
@@ -162,8 +201,10 @@ This operation performs depth-wise 2-D convolution over an input `I` and filter
162201
`F` and generates output `O` using the following computation:
163202

164203
```
165-
O(n, oh, ow, c) = std_addf<kh, kw>(std_mulf(
166-
I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c)))
204+
O(n, oh, ow, c) = std_addf<kh, kw>(
205+
O(n, oh, ow, c),
206+
std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c),
207+
K(kh, kw, c)));
167208
```
168209

169210
where

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,33 @@ func @generalize_matmul_tensor(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C:
7676

7777
// -----
7878

79+
func @depthwise_conv_2d_input_nhwc_filter_hwcf(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
80+
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
81+
{ strides = dense<1> : tensor<2xi64> }
82+
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
83+
outs(%output : memref<2x3x4x2x3xf32>)
84+
return
85+
}
86+
87+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
88+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
89+
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
90+
91+
// CHECK: func @depthwise_conv_2d_input_nhwc_filter_hwcf
92+
93+
// CHECK: linalg.generic
94+
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
95+
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
96+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
97+
// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
98+
99+
// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
100+
// CHECK-NEXT: %[[MUL:.+]] = mulf %[[BBARG0]], %[[BBARG1]] : f32
101+
// CHECK-NEXT: %[[ADD:.+]] = addf %[[BBARG2]], %[[MUL]] : f32
102+
// CHECK-NEXT: linalg.yield %[[ADD]] : f32
103+
104+
// -----
105+
79106
func @depthwise_conv_2d_input_nhwc_filter_hwc(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
80107
linalg.depthwise_conv_2d_input_nhwc_filter_hwc {strides = dense<2> : vector<2xi64>}
81108
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,34 @@
11
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
22

3+
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor
4+
func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32>, %filter: tensor<2x2x2x3xf32>) -> tensor<2x3x4x2x3xf32> {
5+
%zero = constant 0.000000e+00 : f32
6+
%init = linalg.init_tensor [2, 3, 4, 2, 3] : tensor<2x3x4x2x3xf32>
7+
%fill = linalg.fill(%init, %zero) : tensor<2x3x4x2x3xf32>, f32 -> tensor<2x3x4x2x3xf32>
8+
// CHECK: %{{.+}} = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
9+
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
10+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
11+
// CHECK-SAME: outs(%{{.+}} : tensor<2x3x4x2x3xf32>)
12+
%0 = linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
13+
{ strides = dense<1> : tensor<2xi64> }
14+
ins(%input, %filter : tensor<2x4x5x2xf32>, tensor<2x2x2x3xf32>)
15+
outs(%fill : tensor<2x3x4x2x3xf32>) -> tensor<2x3x4x2x3xf32>
16+
return %0 : tensor<2x3x4x2x3xf32>
17+
}
18+
19+
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
20+
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
21+
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
22+
// CHECK-SAME: {strides = dense<1> : tensor<2xi64>}
23+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
24+
// CHECK-SAME: outs(%{{.+}} : memref<2x3x4x2x3xf32>)
25+
linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
26+
{ strides = dense<1> : tensor<2xi64> }
27+
ins(%input, %filter : memref<2x4x5x2xf32>, memref<2x2x2x3xf32>)
28+
outs(%output : memref<2x3x4x2x3xf32>)
29+
return
30+
}
31+
332
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor
433
func @depthwise_conv_2d_input_nhwc_filter_hwc_tensor(%input: tensor<1x113x113x96xf32>, %filter: tensor<3x3x96xf32>) -> tensor<1x56x56x96xf32> {
534
%init = linalg.init_tensor [1, 56, 56, 96] : tensor<1x56x56x96xf32>

0 commit comments

Comments
 (0)