Skip to content

Commit e11fa7a

Browse files
committed
[mlir][linalg] Implement Winograd Conv2D.
This patch implements the Winograd Conv2D algorithm. It supports several configurations of Winograd Conv2D, including F(2, 3), F(4, 3) and F(2, 5). These configurations show that the implementation can support different kernel size (3 and 5) and different output size (2 and 4). Besides symetric kernel size 3x3 and 5x5, this patch also supports 1x3, 3x1, 1x5, and 5x1 kernels. The implementation is based on the paper, Fast Algorithm for Convolutional Neural Networks. (https://arxiv.org/abs/1509.09308)
1 parent 419e7b8 commit e11fa7a

File tree

8 files changed

+1623
-0
lines changed

8 files changed

+1623
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2587,4 +2587,42 @@ def MapCopyToThreadsOp :
25872587
}];
25882588
}
25892589

2590+
//===----------------------------------------------------------------------===//
2591+
// Winograd Conv2D
2592+
//===----------------------------------------------------------------------===//
2593+
2594+
def WinogradConv2DOp : Op<Transform_Dialect,
2595+
"structured.winograd_conv2d",
2596+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2597+
TransformOpInterface, TransformEachOpTrait,
2598+
ReportTrackingListenerFailuresOpTrait]> {
2599+
let description = [{
2600+
Use Winograd Conv2D algorithm to compute Conv2D.
2601+
2602+
#### Return modes:
2603+
2604+
This operation fails if `target` is unsupported. Otherwise, the operation
2605+
succeeds and returns a handle of the sequence that replaces the original
2606+
convolution.
2607+
}];
2608+
2609+
let arguments = (ins TransformHandleTypeInterface:$target);
2610+
let results = (outs TransformHandleTypeInterface:$transformed);
2611+
2612+
let assemblyFormat =
2613+
"$target attr-dict `:` functional-type($target, results)";
2614+
2615+
let builders = [
2616+
OpBuilder<(ins "Value":$target)>
2617+
];
2618+
2619+
let extraClassDeclaration = [{
2620+
::mlir::DiagnosedSilenceableFailure applyToOne(
2621+
::mlir::transform::TransformRewriter &rewriter,
2622+
::mlir::linalg::LinalgOp target,
2623+
::mlir::transform::ApplyToEachResultList &results,
2624+
::mlir::transform::TransformState &state);
2625+
}];
2626+
}
2627+
25902628
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,11 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
13121312
linalg::BatchMatmulOp op,
13131313
bool transposeLHS = true);
13141314

1315+
/// Convert linalg.conv_2d_nhwc_fhwc to a sequence of operations as Winograd
1316+
/// Conv2D algorithm.
1317+
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1318+
linalg::Conv2DNhwcFhwcOp op);
1319+
13151320
//===----------------------------------------------------------------------===//
13161321
// Rewrite patterns wrapping transformations.
13171322
// TODO: every single such pattern should be a close to noop wrapper around a
@@ -1692,6 +1697,9 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
16921697
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
16931698
const ControlBlockPackMatmulFn &controlFn);
16941699

1700+
/// Patterns to apply Winograd Conv2D algorithm.
1701+
void populateWinogradConv2DPatterns(RewritePatternSet &patterns);
1702+
16951703
} // namespace linalg
16961704
} // namespace mlir
16971705

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
34803480
return DiagnosedSilenceableFailure::success();
34813481
}
34823482

3483+
//===----------------------------------------------------------------------===//
3484+
// WinogradConv2DOp
3485+
//===----------------------------------------------------------------------===//
3486+
3487+
DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
3488+
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
3489+
transform::ApplyToEachResultList &results,
3490+
transform::TransformState &state) {
3491+
rewriter.setInsertionPoint(target);
3492+
auto maybeTransformed =
3493+
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
3494+
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
3495+
return winogradConv2D(rewriter, op);
3496+
})
3497+
.Default([&](Operation *op) {
3498+
return rewriter.notifyMatchFailure(op, "not supported");
3499+
});
3500+
3501+
if (failed(maybeTransformed))
3502+
return emitDefaultSilenceableFailure(target);
3503+
3504+
results.push_back(*maybeTransformed);
3505+
return DiagnosedSilenceableFailure::success();
3506+
}
3507+
34833508
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
34843509

34853510
#define GET_OP_CLASSES

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
3838
Transforms.cpp
3939
TransposeConv2D.cpp
4040
Vectorization.cpp
41+
WinogradConv2D.cpp
4142

4243
ADDITIONAL_HEADER_DIRS
4344
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg

0 commit comments

Comments
 (0)