Skip to content

[mlir][linalg] Implement TilingInterface for winograd operators #96179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

Hsiangkai
Copy link
Contributor

In order to support arbitrary size input data of conv2d, implement
TilingInterface for winograd operators. Before converting winograd
operators into nested loops with matrix multiply, tile the input of
conv2d into the supported size first.

Add a transform operator structured.decompose_winograd_op to decompose
winograd operators. Before applying the transform op, use tile_using_for
to tile the input data into supported size. The test case shows how to
tile and decompose winograd operators.

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes

In order to support arbitrary size input data of conv2d, implement
TilingInterface for winograd operators. Before converting winograd
operators into nested loops with matrix multiply, tile the input of
conv2d into the supported size first.

Add a transform operator structured.decompose_winograd_op to decompose
winograd operators. Before applying the transform op, use tile_using_for
to tile the input data into supported size. The test case shows how to
tile and decompose winograd operators.


Patch is 153.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96179.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+129)
  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+88)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+59)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+359)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+52)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+1118)
  • (added) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+332)
  • (added) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+88)
  • (added) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+105)
  • (added) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+248)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+24)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 64c538367267d..45726d6ee2224 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -154,4 +154,133 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
   let hasVerifier = 1;
 }
 
+def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+     ["getIterationDomain",
+      "getLoopIteratorTypes",
+      "getResultTilePosition",
+      "getTiledImplementation"]>]> {
+  let summary = "Winograd filter transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of filter
+    transformation (G x g x G^T) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$filter,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $filter `:` type($filter) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let summary = "Winograd input transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of input
+    transformation (B^T x d x B) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$input,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $input `:` type($input) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
+def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
+    [DeclareOpInterfaceMethods<TilingInterface,
+      ["getIterationDomain",
+       "getLoopIteratorTypes",
+       "getResultTilePosition",
+       "getTiledImplementation"]>]> {
+  let summary = "Winograd output transform operator";
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    This operator is defined to represent the high level concept of output
+    transformation (A^T x y x A) in the Winograd Conv2D algorithm.
+  }];
+
+  let arguments = (ins AnyRankedTensor:$value,
+                       AnyRankedTensor:$output,
+                       I64Attr:$m,
+                       I64Attr:$r
+  );
+
+  let results = (outs AnyRankedTensor:$result);
+  let assemblyFormat = [{
+    attr-dict
+    `m` `(` $m `)`
+    `r` `(` $r `)`
+    `ins` `(` $value `:` type($value) `)`
+    `outs` `(` $output `:` type($output) `)`
+    `->` type($result)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // LINALG_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 93e2c2db729da..71736eae38b4f 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2587,4 +2587,92 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Winograd Conv2D algorithm will convert linalg Conv2D operator into batched
+    matrix multiply. Before the matrix multiply, it will convert filter and
+    input into a format suitable for batched matrix multiply. After the matrix
+    multiply, it will convert output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of output Y is m x m. The size of filter g is r x r. The size of
+    input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and B are
+    transformation matrices.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
+def DecomposeWinogradOp : Op<Transform_Dialect,
+    "structured.decompose_winograd_op",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Decompose winograd operators. It will convert filter, input and output
+    transform operators into a combination of scf, tensor, and linalg
+    equivalent operators. Before applying this transform operator, users
+    need to tile winograd transform operators into supported sizes.
+
+    #### Return modes:
+
+    This operation fails if `target` is unsupported. Otherwise, the operation
+    succeeds and returns a handle of the sequence that replaces the original
+    operator.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 05e97befdec1f..d0eec2be1f8fb 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1312,6 +1312,58 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
                                             linalg::BatchMatmulOp op,
                                             bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of output and r is the dimension
+/// size of filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
+/// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
+/// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
+/// from FHWC first. We need to generate 2 levels of loops to iterate on F and
+/// C. After the rewriting, we get
+///
+/// scf.for %f = lo_f to hi_f step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract filter<h x w> from filter<f x h x w x c>
+///     %ret = linalg.matmul G, %extracted
+///     %ret = linalg.matmul %ret, GT
+///     %inserted = insert %ret into filter<h x w x c x f>
+FailureOr<Operation *>
+decomposeWinogradFilterTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradFilterTransformOp op);
+
+/// Rewrite linalg.winograd_input_transform. The data layout of the input is
+/// NHWC. The transformation matrix is 2-dimension. We need to extract H x W
+/// from NHWC first. We need to generate 2 levels of loops to iterate on N and
+/// C. After the rewriting, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %c = lo_c to hi_c step 1
+///     %extracted = extract input<h x w> from input<n x h x w x c>
+///     %ret = linalg.matmul BT, %extracted
+///     %ret = linalg.matmul %ret, B
+///     %inserted = insert %ret into input<h x w x n x c>
+FailureOr<Operation *>
+decomposeWinogradInputTransformOp(RewriterBase &rewriter,
+                                  linalg::WinogradInputTransformOp op);
+
+/// Rewrite linalg.winograd_output_transform. The data layout of the output is
+/// HWNF. The transformation matrix is 2-dimension. We need to extract H x W
+/// from HWNF first. We need to generate 2 levels of loops to iterate on N and
+/// F. After the transformation, we get
+///
+/// scf.for %n = lo_n to hi_n step 1
+///   scf.for %f = lo_f to hi_f step 1
+///     %extracted = extract input<h x w> from result<h x w x n x f>
+///     %ret = linalg.matmul AT, %extracted
+///     %ret = linalg.matmul %ret, A
+///     %inserted = insert %ret into ret<n x h x w x f>
+FailureOr<Operation *>
+decomposeWinogradOutputTransformOp(RewriterBase &rewriter,
+                                   linalg::WinogradOutputTransformOp op);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
@@ -1692,6 +1744,13 @@ void populateTransposeMatmulPatterns(RewritePatternSet &patterns,
 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
+/// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
+                                    int64_t r);
+
+/// Patterns to decompose Winograd operators.
+void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..a416e1f6e257f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2734,6 +2734,365 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
   return SmallVector<Value>{result};
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradFilterTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradFilterTransformOp::verify() {
+  auto filterType = cast<ShapedType>(getFilter().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto filterElemType = filterType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (filterElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << filterElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned filterRank = filterType.getRank();
+  if (filterRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+SmallVector<Range>
+WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  SmallVector<Range> loopBounds(6);
+  for (unsigned dim = 0; dim < 6; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradFilterTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(6,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+Value getValueFromOpFoldResult(OpFoldResult opFoldResult, OpBuilder &builder,
+                               Location loc) {
+  if (auto val = opFoldResult.dyn_cast<Value>()) {
+    return val;
+  } else if (auto attr = opFoldResult.dyn_cast<Attribute>()) {
+    auto intAttr = cast<IntegerAttr>(attr);
+    return builder.create<arith::ConstantOp>(loc, intAttr);
+  }
+  // This should never happen if OpFoldResult is correctly formed.
+  return nullptr;
+}
+
+LogicalResult WinogradFilterTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  resultOffsets.push_back(offsets[0]);
+  resultOffsets.push_back(offsets[1]);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  resultSizes.push_back(sizes[4]);
+  resultSizes.push_back(sizes[5]);
+
+  return success();
+}
+
+FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
+    OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  Location loc = getLoc();
+  SmallVector<OpFoldResult> strides(6, oneAttr);
+  SmallVector<Value> tiledOperands;
+  tiledOperands.emplace_back(getFilter());
+
+  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
+  if (failed(getResultTilePosition(builder, 1, offsets, sizes, sliceOffsets,
+                                   sliceSizes)))
+    return failure();
+
+  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
+      loc, getOutput(), sliceOffsets, sliceSizes, strides));
+
+  SmallVector<Type, 4> resultTypes;
+  resultTypes.push_back(tiledOperands[1].getType());
+  Operation *tiledOp =
+      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
+
+  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+}
+
+//===----------------------------------------------------------------------===//
+// WinogradInputTransformOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WinogradInputTransformOp::verify() {
+  auto inputType = cast<ShapedType>(getInput().getType());
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  auto inputElemType = inputType.getElementType();
+  auto outputElemType = outputType.getElementType();
+  if (inputElemType != outputElemType) {
+    return emitOpError() << "expected element type of input " << inputElemType
+                         << " to match element type of output "
+                         << outputElemType;
+  }
+
+  unsigned inputRank = inputType.getRank();
+  if (inputRank != 4)
+    return emitOpError() << "expected rank of input is 4";
+
+  unsigned outputRank = outputType.getRank();
+  if (outputRank != 6)
+    return emitOpError() << "expected rank of output is 6";
+
+  return success();
+}
+
+SmallVector<Range>
+WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
+  Location loc = getLoc();
+  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
+  Value output = getOutput();
+  SmallVector<Range> loopBounds(6);
+  for (unsigned dim = 0; dim < 6; ++dim) {
+    loopBounds[dim].offset = zero;
+    loopBounds[dim].size = getDimValue(builder, loc, output, dim);
+    loopBounds[dim].stride = one;
+  }
+  return loopBounds;
+}
+
+SmallVector<utils::IteratorType>
+WinogradInputTransformOp::getLoopIteratorTypes() {
+  SmallVector<utils::IteratorType> iteratorTypes(6,
+                                                 utils::IteratorType::parallel);
+  return iteratorTypes;
+}
+
+LogicalResult WinogradInputTransformOp::getResultTilePosition(
+    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
+    SmallVector<OpFoldResult> &resultSizes) {
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  auto oneAttr = builder.getI64IntegerAttr(1);
+
+  resultOffsets.push_back(offsets[0]);
+  resultOffsets.push_back(offsets[1]);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultOffsets.push_back(zeroAttr);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(oneAttr);
+  resultSizes.push_back(sizes[2]);
+  resultSizes.push_back(sizes[3]);
+  resultSizes.push_back(sizes[4]);
+  resultSizes.push_back(sizes[5]);
+
+  return success();
+}
+
+FailureOr<TilingResult>
+WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
+                                                 ArrayRef<OpFoldResult> offsets,
+                                                 ArrayRef<OpFoldResult> sizes) {
+  auto oneAttr = builder.getI64IntegerAttr(1);
+  auto zeroAttr = builder.getI64IntegerAttr(0);
+  Value input = getInput();
+  auto inputType = cast<ShapedType>(input.getType());
+  auto inputShape = inputType.getShape();
+  int64_t inputH = inputShape[1];
+  int64_t inputW = inputShape[2];
+  int64_t m = getM();
+  int64_t r = getR();
+  int64_t alpha = m + r - 1;
+  int64_t alphaH = inputH != 1 ? alpha : 1;
+  int64_t alphaW = inputW != 1 ? alpha : 1;
+  auto alphaHAttr = builder.getI64IntegerAttr(alphaH);
+  auto alphaWAttr = bui...
[truncated]

@Hsiangkai Hsiangkai requested a review from Max191 June 20, 2024 11:59
@Hsiangkai Hsiangkai force-pushed the users/hsiangkai/winograd-tiling branch from 5fd211c to a5c85c7 Compare June 20, 2024 12:13
@Hsiangkai
Copy link
Contributor Author

Sorry, I am still figuring out how to create stack PRs.

@Hsiangkai Hsiangkai closed this Jun 20, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants