Commit 56d6983
[mlir][linalg] Add transform operator for Winograd Conv2D algorithm
Add a transform operator structured.winograd_conv2d to convert linalg.conv_2d_nhwc_fhwc to Linalg Winograd operators.
1 parent: 0b6f8ae

5 files changed: +130 -0 lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 51 additions & 0 deletions
@@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operator into a
+    batched matrix multiply. Before the matrix multiply, it converts the
+    filter and input into a format suitable for batched matrix
+    multiplication. After the matrix multiply, it converts the output to the
+    final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of the output Y is m x m. The size of the filter g is r x r. The
+    size of the input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T,
+    and B are transformation matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle to the sequence of operations that replaces
+    the original convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   I64Attr:$m,
+                   I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
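For a concrete instance of the formula in the op description, the classical F(2 x 2, 3 x 3) case uses the transformation matrices below. These are the standard values from Lavin and Gray's Winograd convolution paper, given here as background for the reader; they are not part of this patch.

% F(2 x 2, 3 x 3): m = 2, r = 3, so the input tile d is
% (m + r - 1) x (m + r - 1) = 4 x 4.
A^T = \begin{pmatrix} 1 & 1 & 1 & 0 \\ 0 & 1 & -1 & -1 \end{pmatrix}, \quad
G = \begin{pmatrix} 1 & 0 & 0 \\ \tfrac{1}{2} & \tfrac{1}{2} & \tfrac{1}{2} \\ \tfrac{1}{2} & -\tfrac{1}{2} & \tfrac{1}{2} \\ 0 & 0 & 1 \end{pmatrix}, \quad
B^T = \begin{pmatrix} 1 & 0 & -1 & 0 \\ 0 & 1 & 1 & 0 \\ 0 & -1 & 1 & 0 \\ 0 & 1 & 0 & -1 \end{pmatrix}

Here G x g x G^T and B^T x d x B are both 4 x 4, @ multiplies them element-wise, and the outer A^T ... A reduces the result to the 2 x 2 output tile Y. Accumulating the element-wise products over input channels is what lets the rewrite express the middle step as a batched matrix multiply.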

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
@@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to the Winograd Conv2D algorithm
+/// F(m x m, r x r), where m is the dimension size of the output and r is the
+/// dimension size of the filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
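For readers driving this from C++ rather than the transform dialect, a minimal sketch follows. Assumptions: winogradConv2D sits in the mlir::linalg namespace as the surrounding declarations suggest, and rewriteConvsAsWinograd is a hypothetical helper name, not part of this patch.

// Hypothetical sketch, not part of this patch: rewrite every
// linalg.conv_2d_nhwc_fhwc in `func` with the F(4x4, 3x3) Winograd algorithm.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void rewriteConvsAsWinograd(func::FuncOp func) {
  IRRewriter rewriter(func.getContext());
  func.walk([&](linalg::Conv2DNhwcFhwcOp conv) {
    rewriter.setInsertionPoint(conv);
    // winogradConv2D returns failure for convolutions it does not support;
    // those are simply left untouched here.
    (void)linalg::winogradConv2D(rewriter, conv, /*m=*/4, /*r=*/3);
  });
}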

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 0 deletions
@@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  auto maybeTransformed =
+      TypeSwitch<Operation *, FailureOr<Operation *>>(target)
+          .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+            return winogradConv2D(rewriter, op, getM(), getR());
+          })
+          .Default([&](Operation *op) {
+            return rewriter.notifyMatchFailure(op, "not supported");
+          });
+
+  if (failed(maybeTransformed))
+    return emitDefaultSilenceableFailure(target);
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 6 additions & 0 deletions
@@ -218,6 +218,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
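The context lines show the new helper sitting next to the pre-existing pattern hook populateWinogradConv2DPatterns. As a usage note, here is a minimal sketch of driving that hook with the greedy rewriter; applyWinogradPatterns is a hypothetical name, and the linalg namespace placement is assumed from the surrounding code.

// Hypothetical sketch, not part of this patch: greedily apply the Winograd
// rewrite patterns under `root` with output tile size m = 4, filter size r = 3.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult applyWinogradPatterns(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  linalg::populateWinogradConv2DPatterns(patterns, /*m=*/4, /*r=*/3);
  // Rewrites matching linalg.conv_2d_nhwc_fhwc ops until a fixed point.
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}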
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = tensor.empty() : tensor<2x8x8x2xf32>
+  %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (0)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%0 : tensor<2x8x8x2xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<2x8x8x2xf32>
+  %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%1 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %2 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (0)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @conv2d
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x10x10x5xf32>, %[[ARG1:.*]]: tensor<2x3x3x5xf32>, %[[ARG2:.*]]: tensor<1xf32>) -> tensor<2x8x8x2xf32> {
+// CHECK:       %[[S0:.*]] = tensor.empty() : tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S1:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG2]] : tensor<1xf32>) outs(%[[S0]] : tensor<2x8x8x2xf32>) {
+// CHECK-NEXT:  ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: f32):
+// CHECK-NEXT:    linalg.yield %[[IN]] : f32
+// CHECK-NEXT:  } -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  %[[S2:.*]] = tensor.empty() : tensor<12x12x5x2xf32>
+// CHECK-NEXT:  %[[S3:.*]] = linalg.winograd_filter_transform output_height(8) output_width(8) m(4) r(3) ins(%[[ARG1]] : tensor<2x3x3x5xf32>) outs(%[[S2]] : tensor<12x12x5x2xf32>) -> tensor<12x12x5x2xf32>
+// CHECK-NEXT:  %[[S4:.*]] = tensor.empty() : tensor<12x12x2x5xf32>
+// CHECK-NEXT:  %[[S5:.*]] = linalg.winograd_input_transform output_height(8) output_width(8) m(4) r(3) ins(%[[ARG0]] : tensor<2x10x10x5xf32>) outs(%[[S4]] : tensor<12x12x2x5xf32>) -> tensor<12x12x2x5xf32>
+// CHECK-NEXT:  %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x5x2xf32> into tensor<144x5x2xf32>
+// CHECK-NEXT:  %[[COLLAPSED_0:.*]] = tensor.collapse_shape %[[S5]] {{\[}}[0, 1], [2], [3]] : tensor<12x12x2x5xf32> into tensor<144x2x5xf32>
+// CHECK-NEXT:  %[[S6:.*]] = tensor.empty() : tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[S7:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_0]], %[[COLLAPSED]] : tensor<144x2x5xf32>, tensor<144x5x2xf32>) outs(%[[S6]] : tensor<144x2x2xf32>) -> tensor<144x2x2xf32>
+// CHECK-NEXT:  %[[EXPANDED:.*]] = tensor.expand_shape %[[S7]] {{\[}}[0, 1], [2], [3]] output_shape [12, 12, 2, 2] : tensor<144x2x2xf32> into tensor<12x12x2x2xf32>
+// CHECK-NEXT:  %[[S8:.*]] = linalg.winograd_output_transform m(4) r(3) ins(%[[EXPANDED]] : tensor<12x12x2x2xf32>) outs(%[[S1]] : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+// CHECK-NEXT:  return %[[S8]] : tensor<2x8x8x2xf32>
+// CHECK-NEXT: }
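A quick sanity check of the shapes in the CHECK lines, with m = 4 and r = 3 as passed to transform.structured.winograd_conv2d (the arithmetic is editorial, not part of the patch):

% Input tile size: \alpha = m + r - 1 = 4 + 3 - 1 = 6.
% The 8 x 8 output splits into 8 / m = 2 tiles per spatial dimension, so the
% transformed filter/input tensors carry 2 \cdot 6 = 12 per spatial dimension,
% and tensor.collapse_shape folds 12 \cdot 12 = 144 into the batch dimension
% of linalg.batch_matmul, which contracts over the 5 input channels.
\alpha = m + r - 1 = 6, \qquad \tfrac{8}{m} \cdot \alpha = 2 \cdot 6 = 12, \qquad 12 \cdot 12 = 144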
