
[MLIR][Linalg] Fixes for Winograd decomposition and for tiling #123675


Merged
merged 2 commits into llvm:main on Jan 29, 2025

Conversation

d-smirnov
Contributor

This PR fixes the handling of 1 x r and r x 1 filters in the Winograd Conv2D decomposition, and fixes tiling of the Winograd transform ops.

The PR addresses issues with filters 1 x r and r x 1 and with tiling

Signed-off-by: Dmitriy Smirnov <[email protected]>
@llvmbot
Member

llvmbot commented Jan 20, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Dmitriy Smirnov (d-smirnov)

Changes

This PR fixes the handling of 1 x r and r x 1 filters in the Winograd Conv2D decomposition, and fixes tiling of the Winograd transform ops.
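For reference, the shape of the problem being fixed: an F(4, 3) convolution whose filter has width 1 (the r x 1 case). The snippet below is a minimal sketch distilled from the test case added in this patch (the function name is illustrative); only the left-side (height) transform applies, so the alphaW dimension of the output stays 1:

```mlir
// r x 1 filter: tensor<2x3x1x5xf32> is (F, H, W, C) with W == 1, so only
// the left-side transform is performed and the output's alphaW dim is 1.
func.func @filter_rx1_example(%filter: tensor<2x3x1x5xf32>) -> tensor<6x1x5x2xf32> {
  %init = tensor.empty() : tensor<6x1x5x2xf32>
  %t = linalg.winograd_filter_transform m(4) r(3)
         ins(%filter : tensor<2x3x1x5xf32>)
         outs(%init : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
  return %t : tensor<6x1x5x2xf32>
}
```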


Patch is 35.58 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/123675.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (+4-2)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+22-14)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp (+19-15)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir (+107-6)
  • (modified) mlir/test/Dialect/Linalg/transform-tile-winograd.mlir (+34-28)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index e42fd5d2ce13c1..f8df828f74851b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -155,7 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
 }
 
 def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
-    [AllElementTypesMatch<["filter", "output"]>,
+    [AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -220,12 +220,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
     int64_t getFilterCDim() {
       return 3;
     }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
   let hasVerifier = 1;
 }
 
 def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
-    [AllElementTypesMatch<["input", "output"]>,
+    [AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",
@@ -308,6 +309,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
     int64_t getOutputCDim() {
       return 5;
     }
+    MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
   }];
   let hasVerifier = 1;
 }
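Attaching DestinationStyleOpInterface (with getDpsInitsMutable returning the output operand) is what lets generic tiling drive these ops in destination-passing style: the init tensor is threaded through the loop's iter_args and updated with insert_slice, instead of the tiled body slicing a second, dead tensor.empty (see the CHECK diffs further below). A minimal sketch of the resulting loop shape, with hypothetical value names and the 6x6x2x2x2x5 shapes from the existing test:

```mlir
// Sketch only: %tile stands in for the value the tiled transform op would
// produce for iteration %i; the point is that %init is carried through
// iter_args and re-inserted, rather than a fresh empty tensor being sliced.
func.func @dps_tiled_loop(%init: tensor<6x6x2x2x2x5xf32>,
                          %tile: tensor<6x6x1x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %res = scf.for %i = %c0 to %c2 step %c1
      iter_args(%acc = %init) -> (tensor<6x6x2x2x2x5xf32>) {
    %ins = tensor.insert_slice %tile into %acc[0, 0, %i, 0, 0, 0]
        [6, 6, 1, 2, 2, 5] [1, 1, 1, 1, 1, 1]
        : tensor<6x6x1x2x2x5xf32> into tensor<6x6x2x2x2x5xf32>
    scf.yield %ins : tensor<6x6x2x2x2x5xf32>
  }
  return %res : tensor<6x6x2x2x2x5xf32>
}
```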
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..0649fbc8e9549b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3063,8 +3063,11 @@ LogicalResult WinogradInputTransformOp::verify() {
   int m = getM();
   int r = getR();
   int64_t tileSize = m + r - 1;
-  bool leftTransform = inputH != 1;
-  bool rightTransform = inputW != 1;
+
+  auto outputType = cast<ShapedType>(getOutput().getType());
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
+  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
 
   SmallVector<int64_t> expectedOutputShape(6, inputH);
   if (ShapedType::isDynamic(inputH)) {
@@ -3073,7 +3076,7 @@ LogicalResult WinogradInputTransformOp::verify() {
   } else {
     expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
     expectedOutputShape[getOutputTileHDim()] =
-        leftTransform ? (inputH - (r - 1)) / m : 1;
+        leftTransform ? (inputH - (r - 1)) / m : inputH;
   }
   if (ShapedType::isDynamic(inputW)) {
     expectedOutputShape[getOutputAlphaWDim()] = tileSize;
@@ -3081,13 +3084,11 @@ LogicalResult WinogradInputTransformOp::verify() {
   } else {
     expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
     expectedOutputShape[getOutputTileWDim()] =
-        rightTransform ? (inputW - (r - 1)) / m : 1;
+        rightTransform ? (inputW - (r - 1)) / m : inputW;
   }
   expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
   expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
 
-  auto outputType = cast<ShapedType>(getOutput().getType());
-  ArrayRef<int64_t> outputShape = outputType.getShape();
   if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
     return emitOpError("the output shape is not expected");
   }
@@ -3124,15 +3125,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
     ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
     SmallVector<OpFoldResult> &resultSizes) {
   IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
-  ShapedType inputType = getInputOperandType();
-  ArrayRef<int64_t> inputShape = inputType.getShape();
-  int64_t inputH = inputShape[getInputHDim()];
-  int64_t inputW = inputShape[getInputWDim()];
+  ShapedType outputType = getOutputOperandType();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
+  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
+
   int64_t m = getM();
   int64_t r = getR();
   int64_t alpha = m + r - 1;
-  int64_t alphaH = inputH != 1 ? alpha : 1;
-  int64_t alphaW = inputW != 1 ? alpha : 1;
+  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
+  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
+
   IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
   IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
 
@@ -3165,6 +3168,11 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   int64_t m = getM();
   int64_t r = getR();
 
+  ShapedType outputType = getOutputOperandType();
+  ArrayRef<int64_t> outputShape = outputType.getShape();
+  int64_t alphaH = outputShape[getOutputAlphaHDim()];
+  int64_t alphaW = outputShape[getOutputAlphaWDim()];
+
   Location loc = getLoc();
   MLIRContext *context = builder.getContext();
   auto offsetAffineMap =
@@ -3190,9 +3198,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
   sliceOffsets.append(
       {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
   OpFoldResult sizeH =
-      inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+      alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
   OpFoldResult sizeW =
-      inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+      alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
   sliceSizes.append(
       {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
   int64_t inputRank = getInputOperandRank();
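A concrete example of why the left/right decision must come from the output type rather than the input type, taken from the r x 1 test added below: the input tensor<2x6x2x5xf32> has W == 2, so the old inputW != 1 check would wrongly enable the right-side transform. The declared output tensor<6x1x1x2x2x5xf32> has alphaW == 1 and tileW == inputW == 2, which correctly encodes "left transform only" and matches the new expectedOutputShape logic:

```mlir
// Input (N, H, W, C) = (2, 6, 2, 5): W == 2, yet no right-side transform.
// Output (alphaH, alphaW, tileH, tileW, N, C) = (6, 1, 1, 2, 2, 5):
// alphaW == 1 marks the untransformed width, and tileW == inputW == 2.
func.func @input_rx1_example(%input: tensor<2x6x2x5xf32>) -> tensor<6x1x1x2x2x5xf32> {
  %init = tensor.empty() : tensor<6x1x1x2x2x5xf32>
  %t = linalg.winograd_input_transform m(4) r(3)
         ins(%input : tensor<2x6x2x5xf32>)
         outs(%init : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
  return %t : tensor<6x1x1x2x2x5xf32>
}
```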
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 79f77822116fd7..f1059ddf5da2cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -514,12 +514,14 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
     Value CIter = ivs[3];
 
     auto context = builder.getContext();
+
+    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
     auto affineMap =
         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+    Value heightOffset = builder.create<affine::AffineApplyOp>(
+        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+    Value widthOffset = builder.create<affine::AffineApplyOp>(
+        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
 
     // Extract (H, W) from (N, H, W, C).
     auto extractInput =
@@ -753,12 +755,13 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
     Value zero = builder.create<arith::ConstantOp>(
         loc, rewriter.getZeroAttr(elementType));
 
+    auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
     auto affineMap =
         AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
-    Value heightOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
-    Value widthOffset =
-        builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+    Value heightOffset = builder.create<affine::AffineApplyOp>(
+        loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+    Value widthOffset = builder.create<affine::AffineApplyOp>(
+        loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
 
     Value outInitVal =
         extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -1075,16 +1078,17 @@ FailureOr<Operation *>
 decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
                                       linalg::WinogradInputTransformOp op) {
   Location loc = op.getLoc();
-  Value input = op.getInput();
-  auto inputType = cast<ShapedType>(input.getType());
-  auto inputShape = inputType.getShape();
-  int64_t inputH = inputShape[1];
-  int64_t inputW = inputShape[2];
+  Value output = op.getOutput();
+  auto outputType = cast<ShapedType>(output.getType());
+  auto outputShape = outputType.getShape();
+
+  int64_t outputH = outputShape[0];
+  int64_t outputW = outputShape[1];
 
   // For F(m x 1, r x 1), we only need to do left side transform.
-  bool leftTransform = inputH != 1;
+  bool leftTransform = outputH != 1;
   // For F(1 x m, 1 x r), we only need to do right side transform.
-  bool rightTransform = inputW != 1;
+  bool rightTransform = outputW != 1;
   Value transformedInput =
       inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
                      op.getR(), leftTransform, rightTransform);
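The identityAffineMap changes above fix the offset computation for the dimension that is not transformed: per the verifier change, that dimension keeps inputW (or inputH) tiles of extent 1, so the offset is the induction variable itself rather than iv * m. A small sketch with m = 4 as in the tests; the function name is illustrative, while %tileHIter and %tileWIter mirror the variable names in WinogradConv2D.cpp:

```mlir
func.func @tile_offsets(%tileHIter: index, %tileWIter: index) -> (index, index) {
  // Transformed (height) dimension: consecutive tiles start m = 4 apart.
  %hOff = affine.apply affine_map<(d0) -> (d0 * 4)>(%tileHIter)
  // Untransformed (width) dimension: identity map, offset == tile index.
  %wOff = affine.apply affine_map<(d0) -> (d0)>(%tileWIter)
  return %hOff, %wOff : index, index
}
```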
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 776dc5b748c846..a9af874ea40933 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -61,13 +61,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:      scf.yield %[[INSERTED_SLICE]]
 // CHECK:    scf.yield %[[S9]]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
 // CHECK:      %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
 // CHECK:          %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -195,13 +194,12 @@ module attributes {transform.with_named_sequence} {
 // CHECK:    scf.yield %[[S9]] : tensor<6x6x5x2xf32>
 // CHECK:  %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
 // CHECK:  %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK:  %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]])
 // CHECK:    %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
 // CHECK:      %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
 // CHECK:      %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
 // CHECK:      %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:      %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
 // CHECK:      %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
 // CHECK:        %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
 // CHECK:          %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -346,3 +344,106 @@ module attributes {transform.with_named_sequence} {
 // CHECK:       scf.yield %[[INSERTED_SLICE]]
 // CHECK:     scf.yield %[[S7]]
 // CHECK:   return %[[S6]]
+
+// -----
+
+func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<6x1x5x2xf32>
+  %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+  %2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+  %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+  %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+  %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
+  %4 = tensor.empty() : tensor<6x4x2xf32>
+  %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+  %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+  %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
+  %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+  return %7 : tensor<2x4x2x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+    %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+    %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+    %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+    %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME:  (%[[ARG0:.*]]: tensor<2x6x2x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+// CHECK:   %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK:  %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK:  %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK:  %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
+// CHECK:   %[[C1:.*]] = arith.constant 1 : index
+// CHECK:   %[[C5:.*]] = arith.constant 5 : index
+// CHECK:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK:   %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK:   %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK:     %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK:       %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE]]
+// CHECK:     scf.yield %[[S7]]
+// CHECK:   %[[S2:.*]] = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+// CHECK:   %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK:     %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %8, 0] [2, 6, 1, 5] [1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:     %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK:     %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK:       %[[S11:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK:       %[[S12:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S11]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_6]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S12]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK:       %[[INSERTED_SLICE_7:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK:       scf.yield %[[INSERTED_SLICE_7]]
+// CHECK:     scf.yield %[[S10]]
+// CHECK:    %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK:    scf.yield %[[INSERTED_SLICE]]
+// CHECK:   %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK:   %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK:   %[[S4:.*]] = tensor.empty() : tensor<6x4x2xf32>
+// CHECK:   %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK:   %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_4]], %[[COLLAPSED]] : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK:   %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2]
+// CHECK:   %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK:     %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK:     %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK:     %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK:       %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK:       %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0...
[truncated]


@Hsiangkai (Contributor) left a comment


LGTM.

@GeorgeARM merged commit f20b8e3 into llvm:main on Jan 29, 2025
8 checks passed