Skip to content

[mlir][linalg] Implement TilingInterface for winograd operators #96184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 135 additions & 6 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,13 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
let hasVerifier = 1;
}

def Linalg_WinogradFilterTransformOp :
Linalg_Op<"winograd_filter_transform", [AllElementTypesMatch<["filter", "output"]>]> {
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
[AllElementTypesMatch<["filter", "output"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd filter transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
Expand Down Expand Up @@ -190,11 +195,42 @@ def Linalg_WinogradFilterTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let extraClassDeclaration = [{
ShapedType getFilterOperandType() {
return cast<ShapedType>(getFilter().getType());
}
ShapedType getOutputOperandType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getFilterOperandRank() {
return getFilterOperandType().getRank();
}
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
int64_t getFilterFDim() {
return 0;
}
int64_t getFilterHDim() {
return 1;
}
int64_t getFilterWDim() {
return 2;
}
int64_t getFilterCDim() {
return 3;
}
}];
let hasVerifier = 1;
}

def Linalg_WinogradInputTransformOp :
Linalg_Op<"winograd_input_transform", [AllElementTypesMatch<["input", "output"]>]> {
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
[AllElementTypesMatch<["input", "output"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd input transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
Expand Down Expand Up @@ -229,11 +265,60 @@ def Linalg_WinogradInputTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let extraClassDeclaration = [{
ShapedType getInputOperandType() {
return cast<ShapedType>(getInput().getType());
}
ShapedType getOutputOperandType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getInputOperandRank() {
return getInputOperandType().getRank();
}
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
int64_t getInputNDim() {
return 0;
}
int64_t getInputHDim() {
return 1;
}
int64_t getInputWDim() {
return 2;
}
int64_t getInputCDim() {
return 3;
}
int64_t getOutputAlphaHDim() {
return 0;
}
int64_t getOutputAlphaWDim() {
return 1;
}
int64_t getOutputTileHDim() {
return 2;
}
int64_t getOutputTileWDim() {
return 3;
}
int64_t getOutputNDim() {
return 4;
}
int64_t getOutputCDim() {
return 5;
}
}];
let hasVerifier = 1;
}

def Linalg_WinogradOutputTransformOp :
Linalg_Op<"winograd_output_transform", [AllElementTypesMatch<["value", "output"]>]> {
def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
[AllElementTypesMatch<["value", "output"]>,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
"getResultTilePosition",
"getTiledImplementation"]>]> {
let summary = "Winograd output transform operator";
let description = [{
Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
Expand Down Expand Up @@ -268,6 +353,50 @@ def Linalg_WinogradOutputTransformOp :
`outs` `(` $output `:` type($output) `)`
`->` type($result)
}];
let extraClassDeclaration = [{
ShapedType getValueOperandType() {
return cast<ShapedType>(getValue().getType());
}
ShapedType getOutputOperandType() {
return cast<ShapedType>(getOutput().getType());
}
int64_t getValueOperandRank() {
return getValueOperandType().getRank();
}
int64_t getOutputOperandRank() {
return getOutputOperandType().getRank();
}
int64_t getValueAlphaHDim() {
return 0;
}
int64_t getValueAlphaWDim() {
return 1;
}
int64_t getValueTileHDim() {
return 2;
}
int64_t getValueTileWDim() {
return 3;
}
int64_t getValueNDim() {
return 4;
}
int64_t getValueFDim() {
return 5;
}
int64_t getOutputNDim() {
return 0;
}
int64_t getOutputHDim() {
return 1;
}
int64_t getOutputWDim() {
return 2;
}
int64_t getOutputFDim() {
return 3;
}
}];
let hasVerifier = 1;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2697,4 +2697,41 @@ def WinogradConv2DOp : Op<Transform_Dialect,
}];
}

def DecomposeWinogradOp : Op<Transform_Dialect,
"structured.decompose_winograd_op",
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
TransformOpInterface, TransformEachOpTrait,
ReportTrackingListenerFailuresOpTrait]> {
let description = [{
Decompose winograd operations. It will convert filter, input and output
transform operations into a combination of scf, tensor, and linalg
equivalent operations. Before applying this transform operations, users
need to tile winograd transform operations into supported sizes.

#### Return modes:

This operation fails if `target` is unsupported. Otherwise, the operation
succeeds and returns a handle of the sequence that replaces the original
operations.
}];

let arguments = (ins TransformHandleTypeInterface:$target);
let results = (outs TransformHandleTypeInterface:$transformed);

let assemblyFormat =
"$target attr-dict `:` functional-type($target, results)";

let builders = [
OpBuilder<(ins "Value":$target)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::transform::TransformRewriter &rewriter,
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // LINALG_TRANSFORM_OPS
57 changes: 57 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,63 @@ FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
linalg::Conv2DNhwcFhwcOp op, int64_t m,
int64_t r);

/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
/// from FHWC first. We generate 2 levels of loops to iterate on F and C. After
/// the rewriting, we get
///
/// scf.for %f = lo_f to hi_f step 1
/// scf.for %c = lo_c to hi_c step 1
/// %extracted = extract filter<h x w> from filter<f x h x w x c>
/// %ret = linalg.matmul G, %extracted
/// %ret = linalg.matmul %ret, GT
/// %inserted = insert %ret into filter<h x w x c x f>
FailureOr<Operation *>
decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
linalg::WinogradFilterTransformOp op);

/// Rewrite linalg.winograd_input_transform. The data layout of the input is
/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
/// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH,
/// and tileW. After the rewriting, we get
///
/// scf.for %h = 0 to tileH step 1
/// scf.for %w = 0 to tileW step 1
/// scf.for %n = 0 to N step 1
/// scf.for %c = 0 to C step 1
/// %extracted = extract %extracted<alphaH x alphaW> from
/// %input<N x H x W x C>
/// at [%n, (%h x m), (%w x m), %c]
/// %ret = linalg.matmul BT, %extracted
/// %ret = linalg.matmul %ret, B
/// %inserted = insert %ret<alphaH x alphaW> into
/// %output<alphaH x alphaW x tileH x tileW x N x C>
/// at [0, 0, %h, %w, %n, %c]
FailureOr<Operation *>
decomposeWinogradInputTransformOp(RewriterBase &rewriter,
linalg::WinogradInputTransformOp op);

/// Rewrite linalg.winograd_output_transform. The data layout of the output is
/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
/// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH,
/// and tileW. After the transformation, we get
///
/// scf.for %h = 0 to tileH step 1
/// scf.for %w = 0 to tileW step 1
/// scf.for %n = 0 to N step 1
/// scf.for %f = 0 to F step 1
/// %extracted = extract %extracted<alphaH x alphaW> from
/// %input<alphaH x alphaW x tileH x tileW x N x F>
/// at [0, 0, %h, %w, %n, %f]
/// %ret = linalg.matmul AT, %extracted
/// %ret = linalg.matmul %ret, A
/// %inserted = insert %ret<alphaH x alphaW> into
/// output<N x H x W x F>
/// at [%n, (%h x m), (%w x m), %f]
FailureOr<Operation *>
decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
linalg::WinogradOutputTransformOp op);

//===----------------------------------------------------------------------===//
// Rewrite patterns wrapping transformations.
// TODO: every single such pattern should be a close to noop wrapper around a
Expand Down
Loading
Loading