Commit 1fa83da

[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm
Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into Winograd operators. According to the Winograd Conv2D algorithm, we need three transform operators for input, filter, and output transformation.

The formula of the Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper "Fast Algorithms for Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
1 parent 975f0a9 commit 1fa83da
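As a rough illustration of the formula in the commit message, the sketch below evaluates Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A for a single F(2 x 2, 3 x 3) tile and compares it with a direct sliding-window computation. The matrices are the F(2 x 2, 3 x 3) ones from the cited paper. Everything else (the `Mat` alias and the helper names) is hypothetical and not part of this commit. The sketch also assumes that, for a single tile with one channel, the `@` step reduces to an element-wise product; the batched matrix multiply only appears once channels are involved.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>

// Hypothetical fixed-size matrix type used only in this sketch.
template <std::size_t R, std::size_t C>
using Mat = std::array<std::array<double, C>, R>;

template <std::size_t R, std::size_t K, std::size_t C>
Mat<R, C> matmul(const Mat<R, K> &a, const Mat<K, C> &b) {
  Mat<R, C> out{}; // value-initialized to zero
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      for (std::size_t k = 0; k < K; ++k)
        out[i][j] += a[i][k] * b[k][j];
  return out;
}

template <std::size_t R, std::size_t C>
Mat<C, R> transpose(const Mat<R, C> &a) {
  Mat<C, R> out{};
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      out[j][i] = a[i][j];
  return out;
}

// Transformation matrices for F(2 x 2, 3 x 3) as given in the cited paper.
const Mat<4, 4> BT = {{{1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1}}};
const Mat<4, 3> G = {{{1, 0, 0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0, 0, 1}}};
const Mat<2, 4> AT = {{{1, 1, 1, 0}, {0, 1, -1, -1}}};

// Winograd path for one 4 x 4 input tile d and one 3 x 3 filter g.
Mat<2, 2> winogradTile(const Mat<4, 4> &d, const Mat<3, 3> &g) {
  Mat<4, 4> U = matmul(matmul(G, g), transpose(G));   // filter transform
  Mat<4, 4> V = matmul(matmul(BT, d), transpose(BT)); // input transform
  Mat<4, 4> M{};
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 4; ++j)
      M[i][j] = U[i][j] * V[i][j]; // element-wise product (single channel)
  return matmul(matmul(AT, M), transpose(AT));        // output transform
}

// Reference path: direct sliding-window computation of the same 2 x 2 tile.
Mat<2, 2> directTile(const Mat<4, 4> &d, const Mat<3, 3> &g) {
  Mat<2, 2> y{};
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      for (std::size_t p = 0; p < 3; ++p)
        for (std::size_t q = 0; q < 3; ++q)
          y[i][j] += d[i + p][j + q] * g[p][q];
  return y;
}

bool winogradMatchesDirect(const Mat<4, 4> &d, const Mat<3, 3> &g,
                           double tol = 1e-9) {
  Mat<2, 2> w = winogradTile(d, g), r = directTile(d, g);
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      if (std::fabs(w[i][j] - r[i][j]) > tol)
        return false;
  return true;
}

// Usage sketch with an arbitrary tile and filter.
void runExample() {
  const Mat<4, 4> d = {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}};
  const Mat<3, 3> g = {{{1, 0, -1}, {2, 0, -2}, {1, 0, -1}}};
  assert(winogradMatchesDirect(d, g));
  std::printf("y[0][0] = %g\n", winogradTile(d, g)[0][0]);
}
```

Calling `runExample()` from a `main` function exercises both paths. The three intermediate products in `winogradTile` correspond to the three transform operators this commit introduces.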

7 files changed: +769 -0 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 117 additions & 0 deletions
@@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }

+def Linalg_WinogradFilterTransformOp :
+    Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp :
+    Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
+                       TensorRankOf<[AnyType], [6]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp :
+    Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
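
The op descriptions above fix the tile sizes for F(m x m, r x r): the output tile is m x m, the filter is r x r, and the input tile is (m + r - 1) x (m + r - 1). The sketch below simply evaluates those sizes and, as an assumption taken from the cited paper rather than from this commit, compares the multiplications per output tile for the direct computation (m * m * r * r) with those of the element-wise stage of the Winograd form ((m + r - 1)^2). The struct and function names are hypothetical, and the counts ignore the cost of the transforms themselves.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: sizes implied by F(m x m, r x r).
struct WinogradSizes {
  int64_t outputTile; // Y is outputTile x outputTile (= m)
  int64_t filter;     // g is filter x filter (= r)
  int64_t inputTile;  // d is inputTile x inputTile (= m + r - 1)
};

WinogradSizes winogradSizes(int64_t m, int64_t r) {
  assert(m > 0 && r > 0 && "m and r must be positive");
  return {m, r, m + r - 1};
}

// Multiplications per m x m output tile for the direct computation.
int64_t directMultiplies(int64_t m, int64_t r) { return m * m * r * r; }

// Multiplications in the element-wise stage of the Winograd form
// (transform costs are deliberately not counted here).
int64_t winogradMultiplies(int64_t m, int64_t r) {
  int64_t alpha = m + r - 1;
  return alpha * alpha;
}

// Usage sketch for the two configurations discussed in the cited paper.
void printSizes() {
  const int64_t configs[2][2] = {{2, 3}, {4, 3}};
  for (const auto &c : configs) {
    WinogradSizes s = winogradSizes(c[0], c[1]);
    std::printf("F(%lld x %lld, %lld x %lld): input tile %lld, %lld vs %lld multiplies\n",
                (long long)s.outputTile, (long long)s.outputTile,
                (long long)s.filter, (long long)s.filter,
                (long long)s.inputTile,
                (long long)directMultiplies(c[0], c[1]),
                (long long)winogradMultiplies(c[0], c[1]));
  }
}
```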

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);

+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 107 additions & 0 deletions
@@ -2734,6 +2734,113 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }

+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t r = getR();
+
+  if (filterH != r && filterH != 1)
+    return failure();
+  if (filterW != r && filterW != 1)
+    return failure();
+  if (filterH == 1 && filterW == 1)
+    return failure();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
+  int64_t outputTileH = outputShape[2];
+  int64_t outputTileW = outputShape[3];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = inputH != 1;
+  bool rightTransform = inputW != 1;
+
+  if (!leftTransform && !rightTransform)
+    return failure();
+
+  if (leftTransform) {
+    int64_t tileH = (inputH - (r - 1)) / m;
+    if (inputH != tileH * m + (r - 1))
+      return failure();
+    if (tileH != outputTileH)
+      return failure();
+    if (outputH != m + r - 1)
+      return failure();
+  }
+
+  if (rightTransform) {
+    int64_t tileW = (inputW - (r - 1)) / m;
+    if (inputW != tileW * m + (r - 1))
+      return failure();
+    if (tileW != outputTileW)
+      return failure();
+    if (outputW != m + r - 1)
+      return failure();
+  }
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueTileH = valueShape[2];
+  int64_t valueTileW = valueShape[3];
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputH = outputShape[1];
+  int64_t outputW = outputShape[2];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = valueH != 1;
+  bool rightTransform = valueW != 1;
+
+  if (!leftTransform && !rightTransform)
+    return failure();
+
+  if (leftTransform) {
+    if (valueH != m + r - 1)
+      return failure();
+    if (outputH != m * valueTileH)
+      return failure();
+  }
+
+  if (rightTransform) {
+    if (valueW != m + r - 1)
+      return failure();
+    if (outputW != m * valueTileW)
+      return failure();
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//
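
The shape arithmetic in the three verifiers above can be tried outside MLIR. The sketch below restates the per-dimension conditions as plain functions, so it is only an approximation of the verifiers: it does not touch `ShapedType`, and it omits the `leftTransform`/`rightTransform` handling for extents equal to 1. All function and parameter names are hypothetical; `alphaDim` stands for the transformed tile extent that the verifier compares against m + r - 1.

```cpp
#include <cassert>
#include <cstdint>

// Filter check: each spatial extent must be r or 1, and not both 1.
bool filterShapeIsValid(int64_t filterH, int64_t filterW, int64_t r) {
  if (filterH != r && filterH != 1)
    return false;
  if (filterW != r && filterW != 1)
    return false;
  return !(filterH == 1 && filterW == 1);
}

// Input check for one spatial dimension: the input extent must be exactly
// tiles * m + (r - 1), with the tile count and transformed tile extent
// matching the output tensor.
bool inputDimIsValid(int64_t inputDim, int64_t alphaDim, int64_t tileDim,
                     int64_t m, int64_t r) {
  int64_t tiles = (inputDim - (r - 1)) / m;
  if (inputDim != tiles * m + (r - 1))
    return false;
  if (tiles != tileDim)
    return false;
  return alphaDim == m + r - 1;
}

// Output check for one spatial dimension: the transformed tile extent must be
// m + r - 1 and the final extent must be m times the number of tiles.
bool outputDimIsValid(int64_t valueDim, int64_t valueTileDim,
                      int64_t outputDim, int64_t m, int64_t r) {
  if (valueDim != m + r - 1)
    return false;
  return outputDim == m * valueTileDim;
}

// Usage sketch for F(4 x 4, 3 x 3): an input extent of 10 splits into
// (10 - 2) / 4 = 2 tiles, while 11 leaves a remainder and is rejected.
void exampleChecks() {
  assert(filterShapeIsValid(3, 3, 3));
  assert(inputDimIsValid(/*inputDim=*/10, /*alphaDim=*/6, /*tileDim=*/2, 4, 3));
  assert(!inputDimIsValid(/*inputDim=*/11, /*alphaDim=*/6, /*tileDim=*/2, 4, 3));
  assert(outputDimIsValid(/*valueDim=*/6, /*valueTileDim=*/2, /*outputDim=*/8, 4, 3));
}
```

Note that the input check requires the extent to divide into whole tiles, so a caller would presumably need to pad or otherwise align the input before creating the op.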

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp

   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
