Skip to content

[mlir][linalg] Constrain the parameters m, r in Winograd ops #144657

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 25, 2025

Conversation

Hsiangkai
Copy link
Contributor

We only support fixed set of minimum filtering algorithm for Winograd Conv2D decomposition. Instead of letting users specify any integer, define a fixed set of enumeration values for the parameters of minimum filtering algorithm.

We only support fixed set of minimum filtering algorithm for Winograd
Conv2D decomposition. Instead of letting users specify any integer,
define a fixed set of enumeration values for the parameters of minimum
filtering algorithm.
@llvmbot
Copy link
Member

llvmbot commented Jun 18, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Hsiangkai Wang (Hsiangkai)

Changes

We only support fixed set of minimum filtering algorithm for Winograd Conv2D decomposition. Instead of letting users specify any integer, define a fixed set of enumeration values for the parameters of minimum filtering algorithm.


Patch is 86.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144657.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/Linalg.h (+7)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+18)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+6-12)
  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+2-2)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+5-4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+44-13)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+64-79)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+15-15)
  • (modified) mlir/test/Dialect/Linalg/roundtrip.mlir (+12-12)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+12-12)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-winograd.mlir (+18-18)
  • (modified) mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/winograd-conv2d.mlir (+21-21)
  • (modified) mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp (+3-2)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
index 57bf6305a469d..69e09f6b32c2d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -100,6 +100,13 @@ OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
 
 #include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc"
 
+namespace mlir {
+namespace linalg {
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr);
+} // namespace linalg
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // Linalg Attributes
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index ce68afe471fe8..8c98c0b8b8683 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -122,4 +122,22 @@ def TypeFn : I32EnumAttr<"TypeFn", "", [
   let cppNamespace = "::mlir::linalg";
 }
 
+/// We use F(m, r) to define the size of minimal filtering algorithms.
+/// m is the output dimension and r is the filter dimension. We can get
+/// the input dimension, alpha, from the formula, alpha = m + r - 1.
+///
+/// For example, when m = 2 and r = 3, we know its input size is 4.
+/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
+/// 2x2 output result.
+def WinogradConv2DFmr : I32EnumAttr<"WinogradConv2DFmr",
+    "Winograd Conv2D F(m, r)",
+    [
+      I32EnumAttrCase<"F_2_3", 0>,
+      I32EnumAttrCase<"F_4_3", 1>,
+      I32EnumAttrCase<"F_2_5", 2>,
+      I32EnumAttrCase<"Unknown", -1>,
+    ]>{
+  let cppNamespace = "mlir::linalg";
+}
+
 #endif // LINALG_ENUMS
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 1b48bf5fcb237..7ff44c2e1d2ed 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -183,15 +183,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$filter,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $filter `:` type($filter) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -254,15 +252,13 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [4]>:$input,
                        TensorRankOf<[AnyType], [6]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [6]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $input `:` type($input) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
@@ -343,15 +339,13 @@ def Linalg_WinogradOutputTransformOp : Linalg_Op<"winograd_output_transform",
 
   let arguments = (ins TensorRankOf<[AnyType], [6]>:$value,
                        TensorRankOf<[AnyType], [4]>:$output,
-                       I64Attr:$m,
-                       I64Attr:$r
+                       WinogradConv2DFmr:$fmr
   );
 
   let results = (outs TensorRankOf<[AnyType], [4]>:$result);
   let assemblyFormat = [{
     attr-dict
-    `m` `(` $m `)`
-    `r` `(` $r `)`
+    `fmr` `(` $fmr `)`
     `ins` `(` $value `:` type($value) `)`
     `outs` `(` $output `:` type($output) `)`
     `->` type($result)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 15ea5e7bf7159..1b035ec2a457e 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -9,6 +9,7 @@
 #ifndef LINALG_TRANSFORM_OPS
 #define LINALG_TRANSFORM_OPS
 
+include "mlir/Dialect/Linalg/IR/LinalgEnums.td"
 include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
 include "mlir/Dialect/Transform/IR/TransformAttrs.td"
 include "mlir/Dialect/Transform/IR/TransformDialect.td"
@@ -2802,8 +2803,7 @@ def WinogradConv2DOp : Op<Transform_Dialect,
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
-                       I64Attr:$m,
-                       I64Attr:$r);
+                       WinogradConv2DFmr:$fmr);
   let results = (outs TransformHandleTypeInterface:$transformed);
 
   let assemblyFormat =
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 2eef0a06d0eb4..e5ee7724cd32d 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -36,6 +36,7 @@ class BufferizationState;
 namespace linalg {
 
 class LinalgOp;
+enum class WinogradConv2DFmr : uint32_t;
 
 //===----------------------------------------------------------------------===//
 // Utils.
@@ -1337,8 +1338,8 @@ FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter,
 /// F(m x m, r x r). m is the dimension size of output and r is the dimension
 /// size of filter.
 FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
-                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
-                                      int64_t r);
+                                      linalg::Conv2DNhwcFhwcOp op,
+                                      WinogradConv2DFmr fmr);
 
 /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is
 /// FHWC. The transformation matrix is 2-dimension. We need to extract H x W
@@ -1879,8 +1880,8 @@ void populateBlockPackMatmulPatterns(RewritePatternSet &patterns,
                                      const ControlBlockPackMatmulFn &controlFn);
 
 /// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r).
-void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
-                                    int64_t r);
+void populateWinogradConv2DPatterns(RewritePatternSet &patterns,
+                                    WinogradConv2DFmr fmr);
 
 /// Patterns to decompose Winograd operators.
 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 5dbb2403eddbd..ac592ad808311 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2989,8 +2989,9 @@ LogicalResult WinogradFilterTransformOp::verify() {
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t r = getR();
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   if (filterH != r && filterH != 1)
     return emitOpError("expect filter height either equals to r or 1");
@@ -3046,8 +3047,9 @@ LogicalResult WinogradFilterTransformOp::getResultTilePosition(
   ArrayRef<int64_t> filterShape = filterType.getShape();
   int64_t filterH = filterShape[getFilterHDim()];
   int64_t filterW = filterShape[getFilterWDim()];
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = filterH != 1 ? alpha : 1;
   int64_t alphaW = filterW != 1 ? alpha : 1;
@@ -3124,8 +3126,9 @@ LogicalResult WinogradInputTransformOp::verify() {
   ArrayRef<int64_t> inputShape = inputType.getShape();
   int64_t inputH = inputShape[getInputHDim()];
   int64_t inputW = inputShape[getInputWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t tileSize = m + r - 1;
 
   auto outputType = cast<ShapedType>(getOutput().getType());
@@ -3194,8 +3197,9 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
   int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
   int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
 
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   int64_t alpha = m + r - 1;
   int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
   int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
@@ -3224,8 +3228,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
                                                  ArrayRef<OpFoldResult> offsets,
                                                  ArrayRef<OpFoldResult> sizes) {
   IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
-  int64_t m = getM();
-  int64_t r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   ShapedType outputType = getOutputOperandType();
   ArrayRef<int64_t> outputShape = outputType.getShape();
@@ -3303,8 +3308,9 @@ LogicalResult WinogradOutputTransformOp::verify() {
   int64_t valueW = valueShape[getValueAlphaWDim()];
   int64_t valueTileH = valueShape[getValueTileHDim()];
   int64_t valueTileW = valueShape[getValueTileWDim()];
-  int m = getM();
-  int r = getR();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   bool leftTransform = valueH != 1;
   bool rightTransform = valueW != 1;
 
@@ -3365,7 +3371,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
-  int64_t m = getM();
+  WinogradConv2DFmr fmr = getFmr();
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
 
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
@@ -3623,6 +3631,29 @@ verifyExtendedBatchVariantMatmulSemantic(OpTy batchVariantMatmulOp,
 namespace mlir {
 namespace linalg {
 
+WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r) {
+  if (m == 2 && r == 3)
+    return WinogradConv2DFmr::F_2_3;
+  if (m == 4 && r == 3)
+    return WinogradConv2DFmr::F_4_3;
+  if (m == 2 && r == 5)
+    return WinogradConv2DFmr::F_2_5;
+  return WinogradConv2DFmr::Unknown;
+}
+
+std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
+  switch (fmr) {
+  case WinogradConv2DFmr::F_2_3:
+    return {2, 3};
+  case WinogradConv2DFmr::F_4_3:
+    return {4, 3};
+  case WinogradConv2DFmr::F_2_5:
+    return {2, 5};
+  default:
+    return {-1, -1};
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // MatMulOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index b2c28f5eed33c..528c561167824 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -4030,7 +4030,7 @@ DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
   bool supported = TypeSwitch<Operation *, bool>(target)
                        .Case([&](linalg::Conv2DNhwcFhwcOp op) {
                          maybeTransformed =
-                             winogradConv2D(rewriter, op, getM(), getR());
+                             winogradConv2D(rewriter, op, getFmr());
                          return true;
                        })
                        .Default([&](Operation *op) { return false; });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index e4221d4748415..a4f835350fb52 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
@@ -176,19 +177,6 @@ constexpr float A_2x2_5x5[] = {
 };
 // clang-format on
 
-using TransformMapKeyTy = std::pair<int, int>;
-
-/// We use F(m, r) to define the size of minimal filtering algorithms.
-/// m is the output dimension and r is the filter dimension. We can get
-/// the input dimension, alpha, from the formula, alpha = m + r - 1.
-///
-/// For example, when m = 2 and r = 3, we know its input size is 4.
-/// The Conv2D will operate on 4x4 input data with 3x3 filter and get
-/// 2x2 output result.
-constexpr TransformMapKeyTy F_2_3{2, 3};
-constexpr TransformMapKeyTy F_4_3{4, 3};
-constexpr TransformMapKeyTy F_2_5{2, 5};
-
 /// Structure to keep information of constant transform matrices.
 struct TransformMatrix {
   TransformMatrix(const float *table, int64_t rows, int64_t cols,
@@ -344,22 +332,22 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source,
 ///     %ret = linalg.matmul %ret, GT
 ///     %inserted = insert %ret into filter<h x w x c x f>
 Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
-                      Value retValue, int64_t m, int64_t r,
+                      Value retValue, WinogradConv2DFmr fmr,
                       bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to G transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GMatrices = {
-          {F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
-          {F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
-          {F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(G_2x2_3x3, 4, 3)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(G_4x4_3x3, 6, 3)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(G_2x2_5x5, 6, 5)},
       };
 
   // Map from (m, r) to GT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       GTMatrices = {
-          {F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
-          {F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
-          {F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(GT_2x2_3x3, 3, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(GT_4x4_3x3, 3, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(GT_2x2_5x5, 5, 6)},
       };
 
   auto filterType = cast<ShapedType>(filter.getType());
@@ -370,6 +358,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
   int64_t filterW = filterShape[2];
   int64_t filterC = filterShape[3];
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   if (filterH != r && filterH != 1)
     return Value();
   if (filterW != r && filterW != 1)
@@ -387,14 +377,13 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
                             zeroIdx, filterH, filterW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     Value matmulRetValue = extractFilter;
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix G.
-      auto it = GMatrices.find(key);
+      auto it = GMatrices.find(fmr);
       if (it == GMatrices.end())
         return {};
       const TransformMatrix &GMatrix = it->second;
@@ -416,7 +405,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 
     if (rightTransform) {
       // Get constant transform matrix GT.
-      auto it = GTMatrices.find(key);
+      auto it = GTMatrices.find(fmr);
       if (it == GTMatrices.end())
         return {};
       const TransformMatrix &GTMatrix = it->second;
@@ -476,24 +465,26 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter,
 ///                            %output<alphaH x alphaW x tileH x tileW x N x C>
 ///                            at [0, 0, %h, %w, %n, %c]
 Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
-                     Value retValue, int64_t m, int64_t r,
+                     Value retValue, WinogradConv2DFmr fmr,
                      bool leftTransform = true, bool rightTransform = true) {
   // Map from (m, r) to BT transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BTMatrices = {
-          {F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(BT_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(BT_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(BT_2x2_5x5, 6, 6)},
       };
 
   // Map from (m, r) to B transform matrix.
-  static const llvm::SmallDenseMap<TransformMapKeyTy, TransformMatrix>
+  static const llvm::SmallDenseMap<WinogradConv2DFmr, TransformMatrix>
       BMatrices = {
-          {F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
-          {F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
-          {F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
+          {WinogradConv2DFmr::F_2_3, TransformMatrix(B_2x2_3x3, 4, 4)},
+          {WinogradConv2DFmr::F_4_3, TransformMatrix(B_4x4_3x3, 6, 6)},
+          {WinogradConv2DFmr::F_2_5, TransformMatrix(B_2x2_5x5, 6, 6)},
       };
 
+  int64_t m, r;
+  std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
   auto inputType = cast<ShapedType>(input.getType());
   Type elementType = inputType.getElementType();
   auto inputShape = inputType.getShape(); // N, H, W, C
@@ -529,7 +520,6 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
                             widthOffset, alphaH, alphaW, /*loopNorFIdx=*/0,
                             /*loopCorFIdx=*/3, /*heightIdx=*/1, /*widthIdx=*/2);
 
-    TransformMapKeyTy key = {m, r};
     int64_t retRows = 1;
     int64_t retCols = 1;
     Value matmulRetValue = extractInput;
@@ -537,7 +527,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
         loc, rewriter.getZeroAttr(elementType));
     if (leftTransform) {
       // Get constant transform matrix BT.
-      auto it = BTMatrices.find(key);
+      auto it = BTMatrices.find(fmr);
       if (it == BTMatrices.end())
         return {};
       const TransformMatrix &BTMatrix = it->second;
@@ -560,7 +550,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
 
     if (rightTransform) {
       // Get constant transform matrix B.
-      auto it = BMatrices.find(key);
+      auto it = BMatrices.find(fmr);
       if (it == BMatrices.end())
         return {};
       const TransformMatrix &BMatrix = it->second;
@@ -696,24 +686,26 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc,
 ///                            output<N x H x W x F>
 ///                            at [%n, (%h x m), (%w x m), %f]
 Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
-                      Value output, int64_t m, int64_t r,
+                      Value output, WinogradConv2DFmr fmr,
                       bool l...
[truncated]

@GeorgeARM
Copy link
Contributor

GeorgeARM commented Jun 18, 2025

Is the implementation unable to functionally work with different m, r parameters?
If yes constraining is fine; but if it can, then constraint semantics should be orthogonal. Otherwise we restrict experimentation and we might end up with a long list of configurations. Users using Winograd I presume need to be aware of what they are doing. If we want to ease use we could have a utility function that returns some "safe", "good" parameters for a given kernel config.
If we don't have transform matrices we can just create a side-band API to inject/extract transform matrices and work with this? Just brain-storming really here.
Thoughts?

@MaheshRavishankar MaheshRavishankar requested a review from Max191 June 20, 2025 03:52
Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay for me in principle. Please address the comments.

Also please document steps that are needed to add support for other Winograd shapes.

I32EnumAttrCase<"F_2_3", 0>,
I32EnumAttrCase<"F_4_3", 1>,
I32EnumAttrCase<"F_2_5", 2>,
I32EnumAttrCase<"Unknown", -1>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not have "unknown" as an enum entry. This turns a compile-time error into a runtime error (one can pass an instance of the enum into an op constructor and it is only caught at verifier/assertion).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 The Unknown enum is not a valid configuration, so it doesn't seem to be necessary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.


/// Converts the given `m` and `r` parameters to a WinogradConv2DFmr enumeration
/// value.
WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can return std::optional<WinogradConv2DFmr> defaulting to std::nullopt.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

int64_t m = getM();
WinogradConv2DFmr fmr = getFmr();
int64_t m, r;
std::tie(m, r) = getFmrFromWinogradConv2DFmr(fmr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, we could silently continue with m = r = -1 unless we remove the "unknown" enum case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed "unknown" enum case.

Copy link
Contributor

@Max191 Max191 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM other than the comments about the Unknown enum case.


/// Converts the given `m` and `r` parameters to a WinogradConv2DFmr enumeration
/// value.
WinogradConv2DFmr getWinogradConv2DFmr(int64_t m, int64_t r);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

I32EnumAttrCase<"F_2_3", 0>,
I32EnumAttrCase<"F_4_3", 1>,
I32EnumAttrCase<"F_2_5", 2>,
I32EnumAttrCase<"Unknown", -1>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 The Unknown enum is not a valid configuration, so it doesn't seem to be necessary.

@Hsiangkai Hsiangkai merged commit d16f42d into llvm:main Jun 25, 2025
7 checks passed
itf added a commit to itf/llvm-project that referenced this pull request Jun 25, 2025
PR llvm#144657 added include "mlir/Dialect/Linalg/IR/LinalgEnums.td" to mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
itf added a commit that referenced this pull request Jun 25, 2025
#144657 added #include
"mlir/Dialect/Linalg/IR/LinalgEnums.td" to
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td, so
we update the deps
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jun 25, 2025
llvm/llvm-project#144657 added #include
"mlir/Dialect/Linalg/IR/LinalgEnums.td" to
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td, so
we update the deps
itf added a commit to itf/llvm-project that referenced this pull request Jun 25, 2025
llvm#144657 added #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" to mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp,  though it is not in use
itf added a commit that referenced this pull request Jun 25, 2025
#144657 added #include
"mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" to
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp, though it is not
in use
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Jun 25, 2025
…657 (#145749)

llvm/llvm-project#144657 added #include
"mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" to
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp, though it is not
in use
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…4657)

We only support fixed set of minimum filtering algorithm for Winograd
Conv2D decomposition. Instead of letting users specify any integer,
define a fixed set of enumeration values for the parameters of minimum
filtering algorithm.
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
llvm#144657 added #include
"mlir/Dialect/Linalg/IR/LinalgEnums.td" to
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td, so
we update the deps
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…vm#145749)

llvm#144657 added #include
"mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" to
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp, though it is not
in use
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants