Skip to content

Commit 37cb3db

Browse files
HsiangkaiAnthony Tran
authored andcommitted
[mlir][linalg] Constrain the parameters m, r in Winograd ops (llvm#144657)
We only support a fixed set of minimal filtering algorithms for Winograd Conv2D decomposition. Instead of letting users specify arbitrary integers, define a fixed set of enumeration values for the parameters of the minimal filtering algorithm.
1 parent de52a16 commit 37cb3db

File tree

16 files changed

+253
-208
lines changed

16 files changed

+253
-208
lines changed

mlir/include/mlir/Dialect/Linalg/IR/Linalg.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,20 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
100100

101101
#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
102102

103+
namespace mlir {
104+
namespace linalg {
105+
106+
/// Converts the given `m` and `r` parameters to a WinogradConv2DFmr enumeration
107+
/// value. Returns std::nullopt if F(m, r) has no matching enumeration value.
108+
std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r);
109+
110+
/// Converts the given WinogradConv2DFmr enumeration value to a pair of
111+
/// m and r parameters.
112+
std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
113+
114+
} // namespace linalg
115+
} // namespace mlir
116+
103117
//===----------------------------------------------------------------------===//
104118
// Linalg Attributes
105119
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,19 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
122122
let cppNamespace = "::mlir::linalg";
123123
}
124124

125+
/// We use F(m, r) to define the size of minimal filtering algorithms.
126+
/// m is the output dimension and r is the filter dimension. We can get
127+
/// the input dimension, alpha, from the formula, alpha = m + r - 1.
128+
///
129+
/// For example, when m = 2 and r = 3, we know its input size is 4.
130+
/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
131+
/// 2x2 output result.
132+
// Enumerates the F(m, r) configurations that can be named by Winograd ops.
// Each case is spelled F_<m>_<r>.
def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr", "", [
  // m = 2, r = 3.
  I32EnumAttrCase<"F_2_3", 0>,
  // m = 4, r = 3.
  I32EnumAttrCase<"F_4_3", 1>,
  // m = 2, r = 5.
  I32EnumAttrCase<"F_2_5", 2>,
]> {
  // Fully qualified, matching the other enums defined in this file.
  let cppNamespace = "::mlir::linalg";
}
139+
125140
#endif // LINALG_ENUMS

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
183183

184184
let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
185185
TensorRankOf<[AnyType], [4]>:$output,
186-
I64Attr:$m,
187-
I64Attr:$r
186+
WinogradConv2DFmr:$fmr
188187
);
189188

190189
let results = (outs TensorRankOf<[AnyType], [4]>:$result);
191190
let assemblyFormat = [{
192191
attr-dict
193-
`m` `(` $m `)`
194-
`r` `(` $r `)`
192+
`fmr` `(` $fmr `)`
195193
`ins` `(` $filter `:` type($filter) `)`
196194
`outs` `(` $output `:` type($output) `)`
197195
`->` type($result)
@@ -254,15 +252,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
254252

255253
let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
256254
TensorRankOf<[AnyType], [6]>:$output,
257-
I64Attr:$m,
258-
I64Attr:$r
255+
WinogradConv2DFmr:$fmr
259256
);
260257

261258
let results = (outs TensorRankOf<[AnyType], [6]>:$result);
262259
let assemblyFormat = [{
263260
attr-dict
264-
`m` `(` $m `)`
265-
`r` `(` $r `)`
261+
`fmr` `(` $fmr `)`
266262
`ins` `(` $input `:` type($input) `)`
267263
`outs` `(` $output `:` type($output) `)`
268264
`->` type($result)
@@ -343,15 +339,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
343339

344340
let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
345341
TensorRankOf<[AnyType], [4]>:$output,
346-
I64Attr:$m,
347-
I64Attr:$r
342+
WinogradConv2DFmr:$fmr
348343
);
349344

350345
let results = (outs TensorRankOf<[AnyType], [4]>:$result);
351346
let assemblyFormat = [{
352347
attr-dict
353-
`m` `(` $m `)`
354-
`r` `(` $r `)`
348+
`fmr` `(` $fmr `)`
355349
`ins` `(` $value `:` type($value) `)`
356350
`outs` `(` $output `:` type($output) `)`
357351
`->` type($result)

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef LINALG_TRANSFORM_OPS
1010
#define LINALG_TRANSFORM_OPS
1111

12+
include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
1213
include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
1314
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
1415
include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -2902,8 +2903,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
29022903
}];
29032904

29042905
let arguments = (ins TransformHandleTypeInterface:$target,
2905-
I64Attr:$m,
2906-
I64Attr:$r);
2906+
WinogradConv2DFmr:$fmr);
29072907
let results = (outs TransformHandleTypeInterface:$transformed);
29082908

29092909
let assemblyFormat =

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class BufferizationState;
3737
namespace linalg {
3838

3939
class LinalgOp;
40+
enum class WinogradConv2DFmr : uint32_t;
4041

4142
//===----------------------------------------------------------------------===//
4243
// Utils.
@@ -1426,8 +1427,8 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
14261427
/// F(m x m, r x r). m is the dimension size of output and r is the dimension
14271428
/// size of filter.
14281429
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1429-
linalg::Conv2DNhwcFhwcOp op, int64_t m,
1430-
int64_t r);
1430+
linalg::Conv2DNhwcFhwcOp op,
1431+
WinogradConv2DFmr fmr);
14311432

14321433
/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
14331434
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
@@ -1968,8 +1969,8 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
19681969
const ControlBlockPackMatmulFn &controlFn);
19691970

19701971
/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
1971-
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
1972-
int64_t r);
1972+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
1973+
WinogradConv2DFmr fmr);
19731974

19741975
/// Patterns to decompose Winograd operators.
19751976
void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
29892989
ArrayRef<int64_t> filterShape = filterType.getShape();
29902990
int64_t filterH = filterShape[getFilterHDim()];
29912991
int64_t filterW = filterShape[getFilterWDim()];
2992-
int64_t r = getR();
2993-
int64_t m = getM();
2992+
WinogradConv2DFmr fmr = getFmr();
2993+
int64_t m, r;
2994+
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
29942995

29952996
if (filterH != r && filterH != 1)
29962997
return emitOpError("expect filter height either equals to r or 1");
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
30463047
ArrayRef<int64_t> filterShape = filterType.getShape();
30473048
int64_t filterH = filterShape[getFilterHDim()];
30483049
int64_t filterW = filterShape[getFilterWDim()];
3049-
int64_t m = getM();
3050-
int64_t r = getR();
3050+
WinogradConv2DFmr fmr = getFmr();
3051+
int64_t m, r;
3052+
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
30513053
int64_t alpha = m + r - 1;
30523054
int64_t alphaH = filterH != 1 ? alpha : 1;
30533055
int64_t alphaW = filterW != 1 ? alpha : 1;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
31243126
ArrayRef<int64_t> inputShape = inputType.getShape();
31253127
int64_t inputH = inputShape[getInputHDim()];
31263128
int64_t inputW = inputShape[getInputWDim()];
3127-
int m = getM();
3128-
int r = getR();
3129+
WinogradConv2DFmr fmr = getFmr();
3130+
int64_t m, r;
3131+
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
31293132
int64_t tileSize = m + r - 1;
31303133

31313134
auto outputType = cast<ShapedType>(getOutput().getType());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
31943197
int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
31953198
int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
31963199

3197-
int64_t m = getM();
3198-
int64_t r = getR();
3200+
WinogradConv2DFmr fmr = getFmr();
3201+
int64_t m, r;
3202+
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
31993203
int64_t alpha = m + r - 1;
32003204
int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
32013205
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
32243228
ArrayRef<OpFoldResult> offsets,
32253229
ArrayRef<OpFoldResult> sizes) {
32263230
IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3227-
int64_t m = getM();
3228-
int64_t r = getR();
3231+
WinogradConv2DFmr fmr = getFmr();
3232+
int64_t m, r;
3233+
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
32293234

32303235
ShapedType outputType = getOutputOperandType();
32313236
ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
33033308
int64_t valueW = valueShape[getValueAlphaWDim()];
33043309
int64_t valueTileH = valueShape[getValueTileHDim()];
33053310
int64_t valueTileW = valueShape[getValueTileWDim()];
3306-
int m = getM();
3307-
int r = getR();
3311+
WinogradConv2DFmr fmr = getFmr();
3312+
int64_t m, r;
3313+
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
33083314
bool leftTransform = valueH != 1;
33093315
bool rightTransform = valueW != 1;
33103316

@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
33653371
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
33663372
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
33673373
SmallVector<OpFoldResult> &resultSizes) {
3368-
int64_t m = getM();
3374+
WinogradConv2DFmr fmr = getFmr();
3375+
int64_t m, r;
3376+
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
33693377

33703378
Location loc = getLoc();
33713379
MLIRContext *context = builder.getContext();
@@ -3623,6 +3631,27 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
36233631
namespace mlir {
36243632
namespace linalg {
36253633

3634+
std::optional<WinogradConv2DFmr> getWinogradConv2DFmr(int64_t m, int64_t r) {
3635+
if (m == 2 && r == 3)
3636+
return WinogradConv2DFmr::F_2_3;
3637+
if (m == 4 && r == 3)
3638+
return WinogradConv2DFmr::F_4_3;
3639+
if (m == 2 && r == 5)
3640+
return WinogradConv2DFmr::F_2_5;
3641+
return std::nullopt;
3642+
}
3643+
3644+
/// Returns the (m, r) pair encoded by the enumerator `fmr`, with m as the
/// first element and r as the second.
std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
  // Deliberately no `default` label: every enumerator is listed explicitly so
  // that a newly added case is not silently mapped to an existing pair.
  switch (fmr) {
  case WinogradConv2DFmr::F_2_3:
    return {2, 3};
  case WinogradConv2DFmr::F_4_3:
    return {4, 3};
  case WinogradConv2DFmr::F_2_5:
    return {2, 5};
  }
  // Only reachable if `fmr` holds a value outside the enumerators above.
  // Without this, control would fall off the end of a non-void function.
  llvm_unreachable("unknown WinogradConv2DFmr");
}
3654+
36263655
//===----------------------------------------------------------------------===//
36273656
// MatMulOp
36283657
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4250,7 +4250,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
42504250
bool supported = TypeSwitch<Operation *, bool>(target)
42514251
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
42524252
maybeTransformed =
4253-
winogradConv2D(rewriter, op, getM(), getR());
4253+
winogradConv2D(rewriter, op, getFmr());
42544254
return true;
42554255
})
42564256
.Default([&](Operation *op) { return false; });

0 commit comments

Comments
 (0)