
Commit 0c54240

[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm
Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into them. The Winograd Conv2D algorithm needs three transform operators: one each for the input, filter, and output transformations.

The formula of the Winograd Conv2D algorithm is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

filter transform: G x g x G^T
input transform: B^T x d x B
output transform: A^T x y x A

The implementation is based on the paper "Fast Algorithms for Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
1 parent 90779fd commit 0c54240
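
The formula in the commit message can be checked numerically on a single tile. The following is a minimal standalone sketch, not code from this commit: it assumes one channel, reads `@` as an element-wise product (as in the cited paper), and hard-codes the F(2 x 2, 3 x 3) matrices B^T, G, and A^T from that paper. All names (`Mat`, `matmul`, `winogradConvTile`, `directConvTile`, and so on) are hypothetical helpers introduced for illustration.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// Fixed-size row-major matrix with R rows and C columns (illustrative helper).
template <std::size_t R, std::size_t C>
using Mat = std::array<std::array<double, C>, R>;

// Plain matrix product: (R x K) * (K x C) -> (R x C).
template <std::size_t R, std::size_t K, std::size_t C>
Mat<R, C> matmul(const Mat<R, K> &a, const Mat<K, C> &b) {
  Mat<R, C> out{};
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      for (std::size_t k = 0; k < K; ++k)
        out[i][j] += a[i][k] * b[k][j];
  return out;
}

template <std::size_t R, std::size_t C>
Mat<C, R> transpose(const Mat<R, C> &a) {
  Mat<C, R> out{};
  for (std::size_t i = 0; i < R; ++i)
    for (std::size_t j = 0; j < C; ++j)
      out[j][i] = a[i][j];
  return out;
}

// F(2 x 2, 3 x 3) transformation matrices as given in the cited paper.
const Mat<4, 4> kBT = {{{1.0, 0.0, -1.0, 0.0},
                        {0.0, 1.0, 1.0, 0.0},
                        {0.0, -1.0, 1.0, 0.0},
                        {0.0, 1.0, 0.0, -1.0}}};
const Mat<4, 3> kG = {{{1.0, 0.0, 0.0},
                       {0.5, 0.5, 0.5},
                       {0.5, -0.5, 0.5},
                       {0.0, 0.0, 1.0}}};
const Mat<2, 4> kAT = {{{1.0, 1.0, 1.0, 0.0},
                        {0.0, 1.0, -1.0, -1.0}}};

// Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A for one 4 x 4 input tile.
Mat<2, 2> winogradConvTile(const Mat<4, 4> &d, const Mat<3, 3> &g) {
  const Mat<4, 4> u = matmul(matmul(kG, g), transpose(kG));   // filter transform
  const Mat<4, 4> v = matmul(matmul(kBT, d), transpose(kBT)); // input transform
  Mat<4, 4> prod{};
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 4; ++j)
      prod[i][j] = u[i][j] * v[i][j]; // element-wise product
  return matmul(matmul(kAT, prod), transpose(kAT));           // output transform
}

// Reference: direct 3 x 3 convolution (CNN-style correlation, no filter flip).
Mat<2, 2> directConvTile(const Mat<4, 4> &d, const Mat<3, 3> &g) {
  Mat<2, 2> y{};
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      for (std::size_t p = 0; p < 3; ++p)
        for (std::size_t q = 0; q < 3; ++q)
          y[i][j] += d[i + p][j + q] * g[p][q];
  return y;
}

// Largest absolute element-wise difference between two output tiles.
double maxAbsDiff(const Mat<2, 2> &a, const Mat<2, 2> &b) {
  double worst = 0.0;
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      if (std::fabs(a[i][j] - b[i][j]) > worst)
        worst = std::fabs(a[i][j] - b[i][j]);
  return worst;
}

// Compares both paths on an arbitrary, made-up tile and filter.
void checkWinogradAgainstDirect() {
  Mat<4, 4> d{};
  Mat<3, 3> g{};
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 4; ++j)
      d[i][j] = static_cast<double>(i * 4 + j + 1);
  for (std::size_t p = 0; p < 3; ++p)
    for (std::size_t q = 0; q < 3; ++q)
      g[p][q] = 1.0 + 0.25 * static_cast<double>(p) - 0.5 * static_cast<double>(q);
  assert(maxAbsDiff(winogradConvTile(d, g), directConvTile(d, g)) < 1e-9);
}
```

Calling `checkWinogradAgainstDirect()` from a `main` function exercises both paths. The operators added in this commit work on rank-4 and rank-6 tensors, so this single-tile sketch only illustrates the per-tile arithmetic, not the tensor layout.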

7 files changed: 779 additions, 0 deletions

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 114 additions & 0 deletions
@@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }

+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
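
The size relation stated in the op descriptions (an m x m output tile and an r x r filter need an (m + r - 1) x (m + r - 1) input tile) can be turned into a small arithmetic helper. This is a hedged sketch, not part of the commit: `WinogradTileCost` and `winogradTileCost` are hypothetical names, and the multiplication counts follow the cited paper's per-tile, single-channel accounting, which ignores the cost of the three transforms themselves.

```cpp
#include <cassert>
#include <cstdint>

// Per-tile sizes and multiplication counts for F(m x m, r x r).
// Hypothetical helper for illustration only.
struct WinogradTileCost {
  std::int64_t inputTileSize; // m + r - 1 (side length of input tile d)
  std::int64_t winogradMuls;  // (m + r - 1)^2 element-wise multiplications
  std::int64_t directMuls;    // m^2 * r^2 multiplications for direct conv
};

WinogradTileCost winogradTileCost(std::int64_t m, std::int64_t r) {
  assert(m > 0 && r > 0 && "output tile size and filter size must be positive");
  const std::int64_t alpha = m + r - 1;
  return {alpha, alpha * alpha, m * m * r * r};
}
```

For F(2 x 2, 3 x 3) this gives a 4 x 4 input tile with 16 multiplications in the element-wise stage, against 36 for the direct computation of the same 2 x 2 output tile.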

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);

+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 } // namespace linalg
 } // namespace mlir

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 78 additions & 0 deletions
@@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }

+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto filterElemType = filterType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (filterElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << filterElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned filterRank = filterType.getRank();
+  if (filterRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto inputElemType = inputType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (inputElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << inputElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned inputRank = inputType.getRank();
+  if (inputRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto valueElemType = valueType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (valueElemType != outputElemType) {
+    return emitOpError() << "expected element type of value " << valueElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned valueRank = valueType.getRank();
+  if (valueRank != 6)
+    return emitOpError() << "expected rank of input is 6";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 4)
+    return emitOpError() << "expected rank of output is 4";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp

   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
