Skip to content

Commit 8113ac4

Browse files
committed
[mlir][linalg] Add transform operator for Winograd Conv2D algorithm
Add a transform op, `transform.structured.winograd_conv2d`, that converts `linalg.conv_2d_nhwc_fhwc` into the Linalg Winograd operations.
1 parent 1fa83da commit 8113ac4

File tree

5 files changed

+173
-0
lines changed

5 files changed

+173
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
25872587
}];
25882588
}
25892589

2590+
//===----------------------------------------------------------------------===//
2591+
// Winograd Conv2D
2592+
//===----------------------------------------------------------------------===//
2593+
2594+
// Transform op `structured.winograd_conv2d`. It takes a handle to the target
// op plus two I64 attributes, `m` and `r`, and produces one result handle.
// Per the description below, the output Y is m x m and the filter g is r x r
// in the Winograd F(m x m, r x r) algorithm.
def WinogradConv2DOp : Op<Transform_Dialect,
2595+
"structured.winograd_conv2d",
2596+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
2597+
TransformOpInterface, TransformEachOpTrait,
2598+
ReportTrackingListenerFailuresOpTrait]> {
2599+
let description = [{
2600+
Winograd Conv2D algorithm will convert linalg Conv2D operation into batched
2601+
matrix multiply. Before the matrix multiply, it will convert filter and
2602+
input into a format suitable for batched matrix multiply. After the matrix
2603+
multiply, it will convert output to the final result tensor.
2604+
2605+
The algorithm F(m x m, r x r) is
2606+
2607+
Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
2608+
2609+
The size of output Y is m x m. The size of filter g is r x r. The size of
2610+
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
2611+
transformation matrices.
2612+
2613+
#### Return modes:
2614+
2615+
This operation produces a silenceable failure if `target` is unsupported.
2616+
Otherwise, the operation succeeds and returns a handle of the sequence that
2617+
replaces the original convolution.
2618+
}];
2619+
2620+
// `m` and `r` are not named in the assembly format further down, so in the
// textual form they are carried by `attr-dict`.
let arguments = (ins TransformHandleTypeInterface:$target,
2621+
I64Attr:$m,
2622+
I64Attr:$r);
2623+
let results = (outs TransformHandleTypeInterface:$transformed);
2624+
2625+
let assemblyFormat =
2626+
"$target attr-dict `:` functional-type($target, results)";
2627+
2628+
// Builder taking only the target handle; no inline body is given here.
// NOTE(review): confirm that a matching C++ definition exists, and how `m`
// and `r` are meant to be set through this builder.
let builders = [
2629+
OpBuilder<(ins "Value":$target)>
2630+
];
2631+
2632+
// Declares the per-op hook `applyToOne`; its target parameter is typed as
// linalg::LinalgOp.
let extraClassDeclaration = [{
2633+
::mlir::DiagnosedSilenceableFailure applyToOne(
2634+
::mlir::transform::TransformRewriter &rewriter,
2635+
::mlir::linalg::LinalgOp target,
2636+
::mlir::transform::ApplyToEachResultList &results,
2637+
::mlir::transform::TransformState &state);
2638+
}];
2639+
}
2640+
25902641
#endif // LINALG_TRANSFORM_OPS

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
13121312
linalg::BatchMatmulOp op,
13131313
bool transposeLHS = true);
13141314

1315+
/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
1316+
/// F(m x m, r x r). m is the dimension size of output and r is the dimension
1317+
/// size of filter.
///
/// The result is a FailureOr: check it with failed() before dereferencing it
/// to obtain the Operation *.
1318+
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
1319+
linalg::Conv2DNhwcFhwcOp op, int64_t m,
1320+
int64_t r);
1321+
13151322
//===----------------------------------------------------------------------===//
13161323
// Rewrite patterns wrapping transformations.
13171324
// TODO: every single such pattern should be a close to noop wrapper around a

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3480,6 +3480,39 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
34803480
return DiagnosedSilenceableFailure::success();
34813481
}
34823482

3483+
//===----------------------------------------------------------------------===//
3484+
// WinogradConv2DOp
3485+
//===----------------------------------------------------------------------===//
3486+
3487+
/// Applies the Winograd Conv2D rewrite to a single payload op.
///
/// Only linalg::Conv2DNhwcFhwcOp targets are handled; any other op yields a
/// silenceable "not supported" error. If winogradConv2D itself reports
/// failure, a silenceable "apply ... failed" error is returned instead. On
/// success the operation returned by winogradConv2D is appended to `results`.
DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  rewriter.setInsertionPoint(target);

  // Guard: reject every payload op that is not a conv_2d_nhwc_fhwc.
  auto convOp = dyn_cast<linalg::Conv2DNhwcFhwcOp>(target.getOperation());
  if (!convOp)
    return emitSilenceableError()
           << "this operation is not supported to convert to Winograd Conv2D";

  // The op is supported; run the rewrite with this op's `m` and `r`.
  FailureOr<Operation *> rewritten =
      winogradConv2D(rewriter, convOp, getM(), getR());
  if (failed(rewritten))
    return emitSilenceableError() << "apply Winograd Conv2D failed";

  results.push_back(*rewritten);
  return DiagnosedSilenceableFailure::success();
}
3515+
34833516
#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
34843517

34853518
#define GET_OP_CLASSES

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,12 @@ class WinogradConv2DNhwcFhwc final
324324
} // end anonymous namespace
325325

326326
//===----------------------------------------------------------------------===//
327+
/// Public entry point for the Winograd Conv2D rewrite: forwards `rewriter`,
/// `op`, `m` and `r` unchanged to winogradConv2DHelper and returns its result.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
328+
linalg::Conv2DNhwcFhwcOp op, int64_t m,
329+
int64_t r) {
330+
return winogradConv2DHelper(rewriter, op, m, r);
331+
}
332+
327333
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
328334
int64_t r) {
329335
MLIRContext *context = patterns.getContext();
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics| FileCheck %s
2+
3+
// Static-shape case: 10x10 input, 3x3 filter, 8x8 output, rewritten with
// m = 4 and r = 3. Note that %arg2 is not referenced in the function body.
func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
4+
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
5+
return %0 : tensor<2x8x8x2xf32>
6+
}
7+
8+
module attributes {transform.with_named_sequence} {
9+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
10+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
11+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
12+
transform.yield
13+
}
14+
}
15+
16+
// CHECK-LABEL: func.func @conv2d
17+
// CHECK: linalg.winograd_filter_transform m(4) r(3)
18+
// CHECK: linalg.winograd_input_transform m(4) r(3)
19+
// CHECK: linalg.batch_matmul
20+
// CHECK: linalg.winograd_output_transform m(4) r(3)
21+
22+
// -----
23+
24+
// Unaligned case: the 9x9 output height/width is not a multiple of m = 4.
// The checks that follow this case look for tensor.pad ops around the
// Winograd input and output transforms.
func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
25+
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
26+
return %0 : tensor<2x9x9x2xf32>
27+
}
28+
29+
module attributes {transform.with_named_sequence} {
30+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
31+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
32+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
33+
transform.yield
34+
}
35+
}
36+
37+
// CHECK-LABEL: func.func @conv2d_unaligned
38+
// CHECK: linalg.winograd_filter_transform m(4) r(3)
39+
// CHECK: tensor.pad
40+
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
41+
// CHECK: linalg.winograd_input_transform m(4) r(3)
42+
// CHECK: tensor.pad
43+
// CHECK-SAME: low[0, 0, 0, 0] high[0, 3, 3, 0]
44+
// CHECK: linalg.winograd_output_transform m(4) r(3)
45+
46+
// -----
47+
48+
// Negative case: the payload op is linalg.conv_2d_nhwc_hwcf (HWCF filter
// layout) rather than conv_2d_nhwc_fhwc, so the transform op is annotated
// below as reporting a "not supported" diagnostic.
func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
49+
%0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
50+
return %0 : tensor<2x8x8x2xf32>
51+
}
52+
53+
module attributes {transform.with_named_sequence} {
54+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
55+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
56+
// expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
57+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
58+
transform.yield
59+
}
60+
}
61+
62+
// -----
63+
64+
// Negative case: the input and output height/width dimensions are dynamic
// (`?`). The transform op is annotated below as reporting that applying the
// Winograd rewrite failed.
func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
65+
%0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
66+
return %0 : tensor<2x?x?x2xf32>
67+
}
68+
69+
module attributes {transform.with_named_sequence} {
70+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
71+
%0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
72+
// expected-error @+1 {{apply Winograd Conv2D failed}}
73+
%1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
74+
transform.yield
75+
}
76+
}

0 commit comments

Comments
 (0)