Commit dc00c79
[mlir][linalg] Add transform operator for Winograd Conv2D algorithm
Add a transform operator, structured.winograd_conv2d, to convert linalg.conv_2d_nhwc_fhwc into Linalg Winograd operators.
Parent: 276ed89

5 files changed: +177 -0 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 51 additions & 0 deletions
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the filter
+    and the input into a format suitable for batched matrix multiply. After the
+    matrix multiply, it converts the output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of the output Y is m x m. The size of the filter g is r x r. The
+    size of the input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and
+    B are transformation matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle to the sequence of operations that replaces
+    the original convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
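As a concrete instance of the formula in the description: for F(2 x 2, 3 x 3), i.e. m = 2 and r = 3, the standard transformation matrices from Lavin and Gray's "Fast Algorithms for Convolutional Neural Networks" are shown below. These particular matrices are illustrative and not part of this commit, which only takes m and r as attributes.

B^T = \begin{pmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & -1 & 1 & 0 \\ 0 & 1 & 0 & -1 \end{pmatrix}, \quad
G = \begin{pmatrix} 1 & 0 & 0 \\ 1/2 & 1/2 & 1/2 \\ 1/2 & -1/2 & 1/2 \\ 0 & 0 & 1 \end{pmatrix}, \quad
A^T = \begin{pmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & -1 \end{pmatrix}

With a 3 x 3 filter g and a 4 x 4 input tile d (note m + r - 1 = 4), Y = A^T [(G g G^T) \odot (B^T d B)] A is the 2 x 2 output tile, where \odot is the elementwise product written as @ in the description.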

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of the output and r is the
+/// dimension size of the filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
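A minimal sketch (not part of this commit) of calling this C++ entry point directly, e.g. from a pass, instead of going through the transform dialect. The helper name applyWinogradToAll is hypothetical; the rest is the API declared above.

// applyWinogradToAll is a hypothetical helper, shown only as a usage sketch.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static void applyWinogradToAll(Operation *root, int64_t m, int64_t r) {
  // Collect candidates first: winogradConv2D replaces the matched op, so
  // rewriting during the walk would invalidate the traversal.
  SmallVector<linalg::Conv2DNhwcFhwcOp> candidates;
  root->walk([&](linalg::Conv2DNhwcFhwcOp op) { candidates.push_back(op); });

  IRRewriter rewriter(root->getContext());
  for (linalg::Conv2DNhwcFhwcOp op : candidates) {
    rewriter.setInsertionPoint(op);
    // F(m x m, r x r); on failure the convolution is left untouched.
    (void)linalg::winogradConv2D(rewriter, op, m, r);
  }
}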

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 0 deletions
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return winogradConv2D(rewriter, op, getM(), getR());
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
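Note that the TypeSwitch dispatch above leaves room to wire in other convolution layouts later; for now only linalg.conv_2d_nhwc_fhwc is matched, and any other payload operation hits the "not supported" match failure, which applyToOne surfaces as a default silenceable failure rather than a hard error.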

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 6 additions & 0 deletions
@@ -311,6 +311,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
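A minimal sketch (not part of this commit) of the pattern-based route through populateWinogradConv2DPatterns, visible in the surrounding context above, as an alternative to the transform op. The helper runWinogradPatterns is hypothetical.

// runWinogradPatterns is a hypothetical helper, shown only as a usage sketch.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runWinogradPatterns(Operation *root, int64_t m,
                                         int64_t r) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateWinogradConv2DPatterns(patterns, m, r);
  // Greedily rewrites every matching linalg.conv_2d_nhwc_fhwc under `root`.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}

The transform op is useful for precise, script-driven targeting of individual convolutions, while the pattern set converts everything it matches in one sweep.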
mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<2x2x6x6x5x2xf32>) -> tensor<2x2x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<2x2x6x6x2x5xf32>) -> tensor<2x2x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<2x2x6x6x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [2, 2, 6, 6, 2, 2] : tensor<144x2x2xf32> into tensor<2x2x6x6x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<2x2x6x6x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = tensor.empty() : tensor<2x9x9x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x9x9x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x9x9x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %2 : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x11x11x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x9x9x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x9x9x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x9x9x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x9x9x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<3x3x6x6x5x2xf32>) -> tensor<3x3x6x6x5x2xf32>
+// CHECK-NEXT:  %[[INPUT_BUF:.*]] = tensor.empty() : tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[ARG0]] into %[[INPUT_BUF]][0, 0, 0, 0] [2, 11, 11, 5] [1, 1, 1, 1] : tensor<2x11x11x5xf32> into tensor<2x14x14x5xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform m(4) r(3) ins(%[[INSERTED_SLICE]] : tensor<2x14x14x5xf32>) outs(%[[S4]] : tensor<3x3x6x6x2x5xf32>) -> tensor<3x3x6x6x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x5x2xf32> into tensor<324x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1, 2, 3], [4], [5]] : tensor<3x3x6x6x2x5xf32> into tensor<324x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<324x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<324x2x5xf32>, tensor<324x5x2xf32>) outs(%[[S6]] : tensor<324x2x2xf32>) -> tensor<324x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1, 2, 3], [4], [5]] output_shape [3, 3, 6, 6, 2, 2] : tensor<324x2x2xf32> into tensor<3x3x6x6x2x2xf32>
+// CHECK-NEXT:  %[[OUTPUT_BUF:.*]] = tensor.empty() : tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[INSERTED_SLICE_2:.*]] = tensor.insert_slice %[[S1]] into %[[OUTPUT_BUF]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x9x9x2xf32> into tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<3x3x6x6x2x2xf32>) outs(%[[INSERTED_SLICE_2]] : tensor<2x12x12x2xf32>) -> tensor<2x12x12x2xf32>
+// CHECK-NEXT:  %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[S8]][0, 0, 0, 0] [2, 9, 9, 2] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<2x9x9x2xf32>
+// CHECK-NEXT:  return %[[EXTRACTED_SLICE]] : tensor<2x9x9x2xf32>
+// CHECK-NEXT: }
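The shapes in these CHECK lines follow directly from the tiling arithmetic. With m = 4 and r = 3, each input tile has size m + r - 1 = 6. The aligned 8x8 output splits evenly into 2x2 tiles, giving the 2x2x6x6 leading dimensions and a collapsed batch of 2*2*6*6 = 144 for the batch_matmul. The unaligned 9x9 output needs ceil(9/4) = 3 tiles per side, so the input is padded to 3*4 + (3 - 1) = 14 and the output buffer to 3*4 = 12 via tensor.insert_slice, the batch becomes 3*3*6*6 = 324, and a final tensor.extract_slice trims the result back to 9x9.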
