Skip to content

Commit 0b6f8ae

Browse files
committed
[mlir][linalg] Implement Conv2D using Winograd Conv2D algorithm
Define high-level Winograd operators and convert conv_2d_nhwc_fhwc into these operators. The Winograd Conv2D algorithm needs three transform operators: one each for the input, filter, and output transformations. The formula of the Winograd Conv2D algorithm is Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A, where the filter transform is G x g x G^T, the input transform is B^T x d x B, and the output transform is A^T x y x A. The implementation is based on the paper "Fast Algorithms for Convolutional Neural Networks" (https://arxiv.org/abs/1509.09308).
1 parent 7439072 commit 0b6f8ae

File tree

6 files changed

+535
-0
lines changed

6 files changed

+535
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,123 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
154154
let hasVerifier = 1;
155155
}
156156

157+
// Filter transformation (G x g x G^T) of the Winograd Conv2D algorithm.
//
// All tensor operands and the result are constrained to rank 4: the Conv2D
// rewrite that creates and consumes this op indexes four dimensions of the
// transformed tensor, so other ranks cannot be lowered.
//
//   $filter        - filter tensor to transform.
//   $output        - destination tensor supplied through `outs`.
//   $output_height - height of the convolution output (the Conv2D rewrite
//                    passes dimension 1 of the convolution output).
//   $output_width  - width of the convolution output (the Conv2D rewrite
//                    passes dimension 2 of the convolution output).
//   $m, $r         - output size and filter size of F(m x m, r x r).
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
  let summary = "Winograd filter transform operator";
  let description = [{
    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
    matrix multiply. Before the matrix multiply, it will convert filter and
    input into a format suitable for batched matrix multiply. After the matrix
    multiply, it will convert output to the final result tensor.

    The algorithm F(m x m, r x r) is

    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

    The size of output Y is m x m. The size of filter g is r x r. The size of
    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
    transformation matrices.

    This operator is defined to represent the high level concept of filter
    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
  }];

  let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
                       TensorRankOf<[AnyType], [4]>:$output,
                       I64Attr:$output_height,
                       I64Attr:$output_width,
                       I64Attr:$m,
                       I64Attr:$r
  );

  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
  let assemblyFormat = [{
    attr-dict
    `output_height` `(` $output_height `)`
    `output_width` `(` $output_width `)`
    `m` `(` $m `)`
    `r` `(` $r `)`
    `ins` `(` $filter `:` type($filter) `)`
    `outs` `(` $output `:` type($output) `)`
    `->` type($result)
  }];
}
197+
198+
// Input transformation (B^T x d x B) of the Winograd Conv2D algorithm.
//
// All tensor operands and the result are constrained to rank 4: the Conv2D
// rewrite that creates and consumes this op indexes four dimensions of the
// transformed tensor, so other ranks cannot be lowered.
//
//   $input         - input tensor to transform.
//   $output        - destination tensor supplied through `outs`.
//   $output_height - height of the convolution output (the Conv2D rewrite
//                    passes dimension 1 of the convolution output).
//   $output_width  - width of the convolution output (the Conv2D rewrite
//                    passes dimension 2 of the convolution output).
//   $m, $r         - output size and filter size of F(m x m, r x r).
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
  let summary = "Winograd input transform operator";
  let description = [{
    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
    matrix multiply. Before the matrix multiply, it will convert filter and
    input into a format suitable for batched matrix multiply. After the matrix
    multiply, it will convert output to the final result tensor.

    The algorithm F(m x m, r x r) is

    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

    The size of output Y is m x m. The size of filter g is r x r. The size of
    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
    transformation matrices.

    This operator is defined to represent the high level concept of input
    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
  }];

  let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
                       TensorRankOf<[AnyType], [4]>:$output,
                       I64Attr:$output_height,
                       I64Attr:$output_width,
                       I64Attr:$m,
                       I64Attr:$r
  );

  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
  let assemblyFormat = [{
    attr-dict
    `output_height` `(` $output_height `)`
    `output_width` `(` $output_width `)`
    `m` `(` $m `)`
    `r` `(` $r `)`
    `ins` `(` $input `:` type($input) `)`
    `outs` `(` $output `:` type($output) `)`
    `->` type($result)
  }];
}
238+
239+
// Output transformation (A^T x y x A) of the Winograd Conv2D algorithm.
//
// All tensor operands and the result are constrained to rank 4: the Conv2D
// rewrite that creates this op feeds it the rank-4 expanded batched matmul
// result and the rank-4 convolution output.
//
//   $value  - tensor to transform (the Conv2D rewrite passes the expanded
//             batched matmul result).
//   $output - destination tensor supplied through `outs` (the Conv2D rewrite
//             passes the original convolution output operand).
//   $m, $r  - output size and filter size of F(m x m, r x r).
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
  let summary = "Winograd output transform operator";
  let description = [{
    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
    matrix multiply. Before the matrix multiply, it will convert filter and
    input into a format suitable for batched matrix multiply. After the matrix
    multiply, it will convert output to the final result tensor.

    The algorithm F(m x m, r x r) is

    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

    The size of output Y is m x m. The size of filter g is r x r. The size of
    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
    transformation matrices.

    This operator is defined to represent the high level concept of output
    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
  }];

  let arguments = (ins TensorRankOf<[AnyType], [4]>:$value,
                       TensorRankOf<[AnyType], [4]>:$output,
                       I64Attr:$m,
                       I64Attr:$r
  );

  let results = (outs TensorRankOf<[AnyType], [4]>:$result);
  let assemblyFormat = [{
    attr-dict
    `m` `(` $m `)`
    `r` `(` $r `)`
    `ins` `(` $value `:` type($value) `)`
    `outs` `(` $output `:` type($output) `)`
    `->` type($result)
  }];
}
275+
157276
#endif // LINALG_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
16921692
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
16931693
const ControlBlockPackMatmulFn &controlFn);
16941694

1695+
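/// Populates `patterns` with rewrites that convert `linalg.conv_2d_nhwc_fhwc`
/// into the Winograd Conv2D algorithm F(m x m, r x r), where `m` is the output
/// size and `r` is the filter size of the minimal filtering algorithm.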
1696+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
1697+
int64_t r);
1698+
16951699
} // namespace linalg
16961700
} // namespace mlir
16971701

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
3838
Transforms.cpp
3939
TransposeConv2D.cpp
4040
Vectorization.cpp
41+
WinogradConv2D.cpp
4142

4243
ADDITIONAL_HEADER_DIRS
4344
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
//===- WinogradConv2D.cpp - Winograd Conv2D implementation ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Implement Winograd Conv2D algorithm. The implementation is based on the
10+
// paper: Fast Algorithms for Convolutional Neural Networks
11+
// (https://arxiv.org/abs/1509.09308)
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
namespace mlir {
21+
namespace linalg {
22+
23+
namespace {
24+
25+
// Key identifying a minimal filtering algorithm F(m, r): `first` is m and
// `second` is r. The key uses int64_t, the same type as the m/r values it is
// built from, so a large m or r is never narrowed into a value that happens to
// compare equal to a supported configuration.
using TransformMapKeyTy = std::pair<int64_t, int64_t>;

// We use F(m, r) to define the size of minimal filtering algorithms.
// m is the output dimension and r is the filter dimension. We can get
// the input dimension, alpha, from the formula, alpha = m + r - 1.
//
// For example, when m = 2 and r = 3, we know its input size is 4.
// The Conv2D will operate on 4x4 input data with 3x3 filter and get
// 2x2 output result.
constexpr TransformMapKeyTy F_2_3{2, 3};
constexpr TransformMapKeyTy F_4_3{4, 3};
constexpr TransformMapKeyTy F_2_5{2, 5};
37+
38+
// Merges the two leading dimensions of `data` into a single dimension and
// returns the resulting tensor.collapse_shape value. The function indexes
// dimensions 0..3 of `data`, so the value must be a shaped type with four
// dimensions; the result has shape [d0 * d1, d2, d3] and the same element type.
Value collapse2DData(RewriterBase &rewriter, Location loc, Value data) {
  auto sourceType = cast<ShapedType>(data.getType());
  auto dims = sourceType.getShape();

  // Dimensions 0 and 1 form one group; dimensions 2 and 3 are kept as they are.
  SmallVector<ReassociationIndices> groups = {{0, 1}, {2}, {3}};
  auto collapsedType = RankedTensorType::get(
      {dims[0] * dims[1], dims[2], dims[3]}, sourceType.getElementType());

  return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, data,
                                                  groups);
}
48+
49+
// This function generates linalg.batch_matmul to multiply input with filter.
50+
// linalg.batch_matmul only supports 3-dimension data sets. We can treat H x W
51+
// data as the 1-dimension data array. That is to convert [H, W, N, C] to
52+
// [H x W, N, C]. In this way, we can convert 4-dimension input data to
53+
// 3-dimension representation that is suitable for linalg.batch_matmul.
54+
//
55+
// Batched matmul will do the matrix multiply with the reduction on channel.
56+
//
57+
// We get
58+
//
59+
// %collapsed_input = tensor.collapse_shape %input
60+
// %collapsed_filter = tensor.collapse_shape %filter
61+
// %ret = linalg.batch_matmul %collapsed_input, %collapsed_filter
62+
// %expanded_ret = tensor.expand_shape %ret
63+
//
64+
// After this function, we get return value with data layout (H, W, N, F)
65+
//
66+
Value matrixMultiply(RewriterBase &rewriter, Location loc,
67+
Value transformedFilter, Value transformedInput) {
68+
auto collapseFilter = collapse2DData(rewriter, loc, transformedFilter);
69+
auto collapseInput = collapse2DData(rewriter, loc, transformedInput);
70+
71+
// Batched matrix multiply
72+
auto filterType = cast<ShapedType>(transformedFilter.getType());
73+
auto filterShape = filterType.getShape();
74+
auto inputType = cast<ShapedType>(transformedInput.getType());
75+
auto inputElemType = inputType.getElementType();
76+
auto inputShape = inputType.getShape();
77+
78+
auto matmulType = RankedTensorType::get(
79+
{inputShape[0] * inputShape[1], inputShape[2], filterShape[3]},
80+
inputElemType);
81+
Value init = rewriter.create<tensor::EmptyOp>(loc, matmulType.getShape(),
82+
inputElemType);
83+
84+
auto matmulOp = rewriter.create<linalg::BatchMatmulOp>(
85+
loc, matmulType, ValueRange({collapseInput, collapseFilter}),
86+
ValueRange{init});
87+
88+
// Expand matmul result
89+
SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}, {3}};
90+
auto expandType = RankedTensorType::get(
91+
{inputShape[0], inputShape[1], inputShape[2], filterShape[3]},
92+
inputElemType);
93+
auto expandOutput = rewriter.create<tensor::ExpandShapeOp>(
94+
loc, expandType, matmulOp.getResult(0), reassociation);
95+
return expandOutput;
96+
}
97+
98+
// Rewrites `convOp` into the Winograd Conv2D form for F(m, r): a filter
// transform and an input transform, a batched matmul of both results, and an
// output transform that replaces `convOp`.
//
// Returns the created output transform operation on success. Returns failure,
// before creating any operation, when
//   * an operand has a dynamic shape (all sizes below are computed statically),
//   * strides or dilations are not all one (the rewrite never encodes them),
//   * the output/filter shapes are not F(m x m, r x r), F(m x 1, r x 1) or
//     F(1 x m, 1 x r),
//   * (m, r) is not one of the supported configurations,
//   * a transformed filter dimension differs from r, or
//   * a transformed output dimension is not a multiple of m (the tile count is
//     computed with an integer division that would silently drop the rest).
FailureOr<Operation *> winogradConv2DHelper(RewriterBase &rewriter,
                                            linalg::Conv2DNhwcFhwcOp convOp,
                                            int64_t m, int64_t r) {
  Value input = convOp.getInputs()[0];
  Value filter = convOp.getInputs()[1];
  Value output = convOp.getOutputs()[0];

  auto inputType = cast<ShapedType>(input.getType());
  auto filterType = cast<ShapedType>(filter.getType());
  auto outputType = cast<ShapedType>(output.getType());

  if (!inputType.hasStaticShape() || !filterType.hasStaticShape() ||
      !outputType.hasStaticShape())
    return rewriter.notifyMatchFailure(convOp,
                                       "expected operands with static shapes");

  auto hasAllOneValues = [](DenseIntElementsAttr attr) {
    return llvm::all_of(attr, [](const llvm::APInt &element) {
      return element.getSExtValue() == 1;
    });
  };
  if (!hasAllOneValues(convOp.getStrides()) ||
      !hasAllOneValues(convOp.getDilations()))
    return rewriter.notifyMatchFailure(convOp,
                                       "expected unit strides and dilations");

  int64_t outputH = outputType.getShape()[1];
  int64_t outputW = outputType.getShape()[2];
  auto filterShape = filterType.getShape(); // F, H, W, C
  int64_t filterF = filterShape[0];
  int64_t filterH = filterShape[1];
  int64_t filterW = filterShape[2];
  int64_t filterC = filterShape[3];
  auto inputShape = inputType.getShape(); // N, H, W, C
  int64_t inputN = inputShape[0];
  int64_t inputC = inputShape[3];

  // Only support F(m x m, r x r), F(m x 1, r x 1) or F(1 x m, 1 x r)
  if ((outputH != outputW) && (outputH != 1 && outputW != 1))
    return failure();
  if ((filterH != filterW) && (filterH != 1 && filterW != 1))
    return failure();

  if ((outputH == 1 && filterH != 1) || (outputH != 1 && filterH == 1))
    return failure();
  if ((outputW == 1 && filterW != 1) || (outputW != 1 && filterW == 1))
    return failure();

  // Supported (m, r) configurations.
  static const llvm::SmallVector<TransformMapKeyTy, 3> validConfigs = {
      F_2_3, F_4_3, F_2_5};

  TransformMapKeyTy key = {m, r};
  auto it = std::find(validConfigs.begin(), validConfigs.end(), key);
  // If we cannot find the configuration, it means we do not support it yet.
  if (it == validConfigs.end())
    return failure();

  // For F(m x 1, r x 1), we only need to do left side transform.
  bool leftTransform = outputH != 1;
  // For F(1 x m, 1 x r), we only need to do right side transform.
  bool rightTransform = outputW != 1;

  // A transformed filter dimension has to match the r of the configuration.
  if ((leftTransform && filterH != r) || (rightTransform && filterW != r))
    return rewriter.notifyMatchFailure(
        convOp, "filter size does not match the Winograd configuration");

  // The output is split into (size / m) tiles per transformed dimension; a
  // remainder would be dropped by the integer division below. m is nonzero
  // here because (m, r) matched a supported configuration.
  if ((leftTransform && outputH % m != 0) ||
      (rightTransform && outputW % m != 0))
    return rewriter.notifyMatchFailure(
        convOp, "output size is not a multiple of the output tile size");

  // All the criteria are satisfied. We can do Winograd Conv2D.
  Location loc = convOp.getLoc();

  // Create operator for filter transform
  Type elementType = filterType.getElementType();
  int64_t alphaH = leftTransform ? m + r - 1 : 1;
  int64_t alphaW = rightTransform ? m + r - 1 : 1;
  int64_t retHeight = leftTransform ? (outputH / m) * alphaH : 1;
  int64_t retWidth = rightTransform ? (outputW / m) * alphaW : 1;
  auto retType = RankedTensorType::get({retHeight, retWidth, filterC, filterF},
                                       elementType);
  Value retValue =
      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
  auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>(
      loc, retType, filter, retValue, outputH, outputW, m, r);

  // Create operator for input transform
  retType =
      RankedTensorType::get({retHeight, retWidth, inputN, inputC}, elementType);
  retValue =
      rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), elementType);
  auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>(
      loc, retType, input, retValue, outputH, outputW, m, r);

  Value matmulRet =
      matrixMultiply(rewriter, loc, transformedFilter, transformedInput);

  // Create operator for output transform
  auto transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>(
      loc, outputType, matmulRet, output, m, r);

  rewriter.replaceOp(convOp, transformedOutput);

  return transformedOutput.getOperation();
}
181+
182+
class WinogradConv2DNhwcFhwc final
183+
: public OpRewritePattern<linalg::Conv2DNhwcFhwcOp> {
184+
public:
185+
using OpRewritePattern::OpRewritePattern;
186+
WinogradConv2DNhwcFhwc(mlir::MLIRContext *context, int64_t m, int64_t r)
187+
: OpRewritePattern(context), m(m), r(r) {}
188+
189+
LogicalResult matchAndRewrite(linalg::Conv2DNhwcFhwcOp convOp,
190+
PatternRewriter &rewriter) const override {
191+
Value filter = convOp.getInputs()[1];
192+
auto filterType = cast<ShapedType>(filter.getType());
193+
auto filterShape = filterType.getShape(); // F, H, W, C
194+
int64_t filterH = filterShape[1];
195+
int64_t filterW = filterShape[2];
196+
Value output = convOp.getOutputs()[0];
197+
auto outputType = cast<ShapedType>(output.getType());
198+
auto outputShape = outputType.getShape(); // F, H, W, C
199+
int64_t outputH = outputShape[1];
200+
int64_t outputW = outputShape[2];
201+
202+
if (filterH != r && filterH != 1 && filterW != r && filterW != 1)
203+
return failure();
204+
205+
if (outputH < m && outputH != 1 && outputW < m && outputW != 1)
206+
return failure();
207+
208+
if (failed(winogradConv2DHelper(rewriter, convOp, m, r)))
209+
return failure();
210+
211+
return success();
212+
}
213+
214+
private:
215+
int64_t m;
216+
int64_t r;
217+
};
218+
} // end anonymous namespace
219+
220+
//===----------------------------------------------------------------------===//
221+
// Adds the conv_2d_nhwc_fhwc -> Winograd Conv2D rewrite, configured for
// F(m, r), to `patterns`. The pattern is created in the context owned by
// `patterns`.
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                    int64_t r) {
  patterns.insert<WinogradConv2DNhwcFhwc>(patterns.getContext(), m, r);
}
226+
227+
} // end namespace linalg
228+
} // end namespace mlir

0 commit comments

Comments
 (0)