Skip to content

[mlir][linalg] Add transform operator for Winograd Conv2D algorithm #96177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -154,4 +154,118 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}

def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform"> {
let summary = "Winograd filter transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
matrix multiply. Before the matrix multiply, it will convert filter and
input into a format suitable for batched matrix multiply. After the matrix
multiply, it will convert output to the final result tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

This operator is defined to represent the high level concept of filter
transformation (G x g x G^T) in the Winograd Conv2D algorithm.
}];

let arguments = (ins AnyRankedTensor:$filter,
AnyRankedTensor:$output,
I64Attr:$m,
I64Attr:$r
);

let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
attr-dict
`m` `(` $m `)`
`r` `(` $r `)`
`ins` `(` $filter `:` type($filter) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let hasVerifier = 1;
}

def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform"> {
let summary = "Winograd input transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
matrix multiply. Before the matrix multiply, it will convert filter and
input into a format suitable for batched matrix multiply. After the matrix
multiply, it will convert output to the final result tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

This operator is defined to represent the high level concept of input
transformation (B^T x d x B) in the Winograd Conv2D algorithm.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
I64Attr:$m,
I64Attr:$r
);

let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
attr-dict
`m` `(` $m `)`
`r` `(` $r `)`
`ins` `(` $input `:` type($input) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let hasVerifier = 1;
}

def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform"> {
let summary = "Winograd output transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
matrix multiply. Before the matrix multiply, it will convert filter and
input into a format suitable for batched matrix multiply. After the matrix
multiply, it will convert output to the final result tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

This operator is defined to represent the high level concept of output
transformation (A^T x y x A) in the Winograd Conv2D algorithm.
}];

let arguments = (ins AnyRankedTensor:$value,
AnyRankedTensor:$output,
I64Attr:$m,
I64Attr:$r
);

let results = (outs AnyRankedTensor:$result);
let assemblyFormat = [{
attr-dict
`m` `(` $m `)`
`r` `(` $r `)`
`ins` `(` $value `:` type($value) `)`
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let hasVerifier = 1;
}

#endif // LINALG_OPS
Original file line number Diff line number Diff line change
Expand Up @@ -2587,4 +2587,55 @@ def MapCopyToThreadsOp :
}];
}

//===----------------------------------------------------------------------===//
// Winograd Conv2D
//===----------------------------------------------------------------------===//

def WinogradConv2DOp : Op<Transform_Dialect,
"structured.winograd_conv2d",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
matrix multiply. Before the matrix multiply, it will convert filter and
input into a format suitable for batched matrix multiply. After the matrix
multiply, it will convert output to the final result tensor.

The algorithm F(m x m, r x r) is

Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A

The size of output Y is m x m. The size of filter g is r x r. The size of
input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
transformation matrices.

#### Return modes:

This operation fails if `target` is unsupported. Otherwise, the operation
succeeds and returns a handle of the sequence that replaces the original
convolution.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
I64Attr:$m,
I64Attr:$r);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:` functional-type($target, results)";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::linalg::LinalgOp target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // LINALG_TRANSFORM_OPS
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,13 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
linalg::BatchMatmulOp op,
bool transposeLHS = true);

/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
/// F(m x m, r x r). m is the dimension size of output and r is the dimension
/// size of filter.
FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);

//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
Expand Down Expand Up @@ -1692,6 +1699,10 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
const ControlBlockPackMatmulFn &controlFn);

/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
int64_t r);

} // namespace linalg
} // namespace mlir

Expand Down
78 changes: 78 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2734,6 +2734,84 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
return SmallVector<Value>{result};
}

//===----------------------------------------------------------------------===//
// WinogradFilterTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradFilterTransformOp::verify() {
auto filterType = cast<ShapedType>(getFilter().getType());
auto outputType = cast<ShapedType>(getOutput().getType());
auto filterElemType = filterType.getElementType();
auto outputElemType = outputType.getElementType();
if (filterElemType != outputElemType) {
return emitOpError() << "expected element type of input " << filterElemType
<< " to match element type of output "
<< outputElemType;
}

unsigned filterRank = filterType.getRank();
if (filterRank != 4)
return emitOpError() << "expected rank of input is 4";

unsigned outputRank = outputType.getRank();
if (outputRank != 6)
return emitOpError() << "expected rank of output is 6";

return success();
}

//===----------------------------------------------------------------------===//
// WinogradInputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradInputTransformOp::verify() {
auto inputType = cast<ShapedType>(getInput().getType());
auto outputType = cast<ShapedType>(getOutput().getType());
auto inputElemType = inputType.getElementType();
auto outputElemType = outputType.getElementType();
if (inputElemType != outputElemType) {
return emitOpError() << "expected element type of input " << inputElemType
<< " to match element type of output "
<< outputElemType;
}

unsigned inputRank = inputType.getRank();
if (inputRank != 4)
return emitOpError() << "expected rank of input is 4";

unsigned outputRank = outputType.getRank();
if (outputRank != 6)
return emitOpError() << "expected rank of output is 6";

return success();
}

//===----------------------------------------------------------------------===//
// WinogradOutputTransformOp
//===----------------------------------------------------------------------===//

LogicalResult WinogradOutputTransformOp::verify() {
auto valueType = cast<ShapedType>(getValue().getType());
auto outputType = cast<ShapedType>(getOutput().getType());
auto valueElemType = valueType.getElementType();
auto outputElemType = outputType.getElementType();
if (valueElemType != outputElemType) {
return emitOpError() << "expected element type of value " << valueElemType
<< " to match element type of output "
<< outputElemType;
}

unsigned valueRank = valueType.getRank();
if (valueRank != 6)
return emitOpError() << "expected rank of input is 6";

unsigned outputRank = outputType.getRank();
if (outputRank != 4)
return emitOpError() << "expected rank of output is 4";

return success();
}

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//
Expand Down
25 changes: 25 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3480,6 +3480,31 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// WinogradConv2DOp
//===----------------------------------------------------------------------===//

DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
transform::TransformRewriter &rewriter, linalg::LinalgOp target,
transform::ApplyToEachResultList &results,
transform::TransformState &state) {
rewriter.setInsertionPoint(target);
auto maybeTransformed =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
.Case([&](linalg::Conv2DNhwcFhwcOp op) {
return winogradConv2D(rewriter, op, getM(), getR());
})
.Default([&](Operation *op) {
return rewriter.notifyMatchFailure(op, "not supported");
});

if (failed(maybeTransformed))
return emitDefaultSilenceableFailure(target);

results.push_back(*maybeTransformed);
return DiagnosedSilenceableFailure::success();
}

#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"

#define GET_OP_CLASSES
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
Transforms.cpp
TransposeConv2D.cpp
Vectorization.cpp
WinogradConv2D.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
Expand Down
Loading
Loading