
Commit 7d246e8

[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm (#96181)
Define high level winograd operators and convert conv_2d_nhwc_fhwc into
winograd operators. According to the Winograd Conv2D algorithm, we need
three transform operators for input, filter, and output transformation.

The formula of the Winograd Conv2D algorithm is

  Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

  filter transform: G x g x G^T
  input transform: B^T x d x B
  output transform: A^T x y x A

The implementation is based on the paper, Fast Algorithms for
Convolutional Neural Networks (https://arxiv.org/abs/1509.09308).

Reviewers: stellaraccident, ftynse, Max191, GeorgeARM, cxy-1993, nicolasvasilache, MaheshRavishankar, dcaballe, rengolin

Reviewed By: ftynse, Max191, stellaraccident

Pull Request: #96181
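(A worked count, not part of the commit message: with F(4 x 4, 3 x 3), i.e. m = 4 and r = 3, each (m + r - 1) x (m + r - 1) = 6 x 6 input tile d produces a 4 x 4 output tile Y. The element-wise product @ costs 36 multiplies per tile, versus m^2 x r^2 = 144 multiplies for direct convolution of the same tile, a 4x reduction.)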
1 parent 015526b commit 7d246e8

File tree

9 files changed: 943 additions & 0 deletions

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 117 additions & 0 deletions
@@ -154,4 +154,121 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp :
+    Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and the input into a format suitable for batched matrix multiply.
+    After the matrix multiply, it converts the output into the final result
+    tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of the output Y is m x m. The size of the filter g is r x r. The
+    size of the input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and
+    B are transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
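For illustration (a hypothetical instance, not part of the diff), with F(4 x 4, 3 x 3), i.e. m(4) r(3), a filter in FHWC layout with 2 output channels, a 3 x 3 kernel, and 5 input channels is transformed into 6 x 6 tiles laid out as (H, W, C, F), matching the shape the verifier below expects:

  %1 = linalg.winograd_filter_transform m(4) r(3)
      ins(%filter : tensor<2x3x3x5xf32>)
      outs(%init : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>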
+
+def Linalg_WinogradInputTransformOp :
+    Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and the input into a format suitable for batched matrix multiply.
+    After the matrix multiply, it converts the output into the final result
+    tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of the output Y is m x m. The size of the filter g is r x r. The
+    size of the input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and
+    B are transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
+                       TensorRankOf<[AnyType], [6]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [6]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
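Similarly (hypothetical, m(4) r(3)): a 2x10x10x5 NHWC input yields 6 x 6 input tiles, with (10 - (3 - 1)) / 4 = 2 tiles per spatial dimension, in what the verifier below implies is a (tileH, tileW, numTileH, numTileW, N, C) layout:

  %2 = linalg.winograd_input_transform m(4) r(3)
      ins(%input : tensor<2x10x10x5xf32>)
      outs(%init : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>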
+
+def Linalg_WinogradOutputTransformOp :
+    Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and the input into a format suitable for batched matrix multiply.
+    After the matrix multiply, it converts the output into the final result
+    tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of the output Y is m x m. The size of the filter g is r x r. The
+    size of the input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and
+    B are transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
+                       TensorRankOf<[AnyType], [4]>:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
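And the inverse direction (hypothetical, m(4) r(3)): 6 x 6 tiles of the batched matrix multiply result are folded back into a 2x8x8x2 NHWF result, which matches a 10x10 input convolved with a 3x3 filter (10 - 3 + 1 = 8):

  %3 = linalg.winograd_output_transform m(4) r(3)
      ins(%value : tensor<6x6x2x2x2x2xf32>)
      outs(%init : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>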
+
 #endif // LINALG_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
@@ -1735,6 +1735,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
 /// Adds patterns that reduce the rank of named contraction ops that have
 /// unit dimensions in the operand(s) by converting to a sequence of `collapse_shape`,
 /// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For example a
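As an illustration of the rewrite these patterns perform (a schematic sketch, not taken from the commit; the reshapes and the batched matrix multiply between the two transforms are elided), a 3 x 3 linalg.conv_2d_nhwc_fhwc with m = 4 and r = 3 becomes:

  %f = linalg.winograd_filter_transform m(4) r(3)
      ins(%filter : tensor<2x3x3x5xf32>)
      outs(%f_init : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>
  %i = linalg.winograd_input_transform m(4) r(3)
      ins(%input : tensor<2x10x10x5xf32>)
      outs(%i_init : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>
  // ... batched matmul of the two transformed tensors produces %v ...
  %out = linalg.winograd_output_transform m(4) r(3)
      ins(%v : tensor<6x6x2x2x2x2xf32>)
      outs(%o_init : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>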

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 116 additions & 0 deletions
@@ -2739,6 +2739,122 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  ArrayRef<int64_t> filterShape = filterType.getShape();
+  int64_t filterH = filterShape[1];
+  int64_t filterW = filterShape[2];
+  int64_t r = getR();
+  int64_t m = getM();
+
+  if (filterH != r && filterH != 1)
+    return emitOpError("expect filter height either equals to r or 1");
+  if (filterW != r && filterW != 1)
+    return emitOpError("expect filter width either equals to r or 1");
+  if (filterH == 1 && filterW == 1)
+    return emitOpError("expect either filter height or width equals to r");
+
+  SmallVector<int64_t> expectedOutputShape;
+  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
+  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
+  expectedOutputShape.push_back(filterShape[3]);
+  expectedOutputShape.push_back(filterShape[0]);
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+    return emitOpError("the output shape is not expected");
+  }
+  return success();
+}
+
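The filterH/filterW == 1 branches admit one-dimensional variants F(m x 1, r x 1), presumably for decomposed convolutions; e.g. (a hypothetical instance for m = 4, r = 3) a 3 x 1 column filter:

  %t = linalg.winograd_filter_transform m(4) r(3)
      ins(%filter : tensor<2x3x1x5xf32>)
      outs(%init : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>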
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  ArrayRef<int64_t> inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int m = getM();
+  int r = getR();
+  int64_t tileSize = m + r - 1;
+  bool leftTransform = inputH != 1;
+  bool rightTransform = inputW != 1;
+
+  SmallVector<int64_t> expectedOutputShape(6, inputH);
+  if (ShapedType::isDynamic(inputH)) {
+    expectedOutputShape[0] = tileSize;
+    expectedOutputShape[2] = ShapedType::kDynamic;
+  } else {
+    expectedOutputShape[0] = leftTransform ? tileSize : 1;
+    expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
+  }
+  if (ShapedType::isDynamic(inputW)) {
+    expectedOutputShape[1] = tileSize;
+    expectedOutputShape[3] = ShapedType::kDynamic;
+  } else {
+    expectedOutputShape[1] = rightTransform ? tileSize : 1;
+    expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
+  }
+  expectedOutputShape[4] = inputShape[0];
+  expectedOutputShape[5] = inputShape[3];
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+    return emitOpError("the output shape is not expected");
+  }
+  return success();
+}
+
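When the spatial dimensions are dynamic, the verifier still fixes the tile dimensions to m + r - 1 and leaves the tile counts dynamic, so a conforming instance would be (hypothetical, m = 4, r = 3):

  %t = linalg.winograd_input_transform m(4) r(3)
      ins(%input : tensor<2x?x?x5xf32>)
      outs(%init : tensor<6x6x?x?x2x5xf32>) -> tensor<6x6x?x?x2x5xf32>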
+//===----------------------------------------------------------------------===//
+// WinogradOutputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradOutputTransformOp::verify() {
+  auto valueType = cast<ShapedType>(getValue().getType());
+  ArrayRef<int64_t> valueShape = valueType.getShape();
+  int64_t valueH = valueShape[0];
+  int64_t valueW = valueShape[1];
+  int64_t valueTileH = valueShape[2];
+  int64_t valueTileW = valueShape[3];
+  int m = getM();
+  int r = getR();
+  bool leftTransform = valueH != 1;
+  bool rightTransform = valueW != 1;
+
+  SmallVector<int64_t> expectedOutputShape(4, valueH);
+  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
+    expectedOutputShape[1] = ShapedType::kDynamic;
+  } else {
+    if (valueH != (leftTransform ? m + r - 1 : 1))
+      return emitOpError("expect input height equals to input tile size");
+    expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
+  }
+  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
+    expectedOutputShape[2] = ShapedType::kDynamic;
+  } else {
+    if (valueW != (rightTransform ? m + r - 1 : 1))
+      return emitOpError("expect input width equals to input tile size");
+    expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
+  }
+  expectedOutputShape[0] = valueShape[4];
+  expectedOutputShape[3] = valueShape[5];
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
+    return emitOpError("the output shape is not expected");
+  }
+  return success();
+}
+
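Correspondingly, dynamic tile counts propagate to dynamic spatial sizes in the output (hypothetical, m = 4, r = 3):

  %res = linalg.winograd_output_transform m(4) r(3)
      ins(%value : tensor<6x6x?x?x2x2xf32>)
      outs(%init : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>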
 //===----------------------------------------------------------------------===//
 // LinalgDialect
 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   Transforms.cpp
   TransposeConv2D.cpp
   Vectorization.cpp
+  WinogradConv2D.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
