Skip to content

[mlir]Fix compose subview #80551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 7, 2024
Merged

Conversation

linuxlonelyeagle
Copy link
Member

I found a bug in the `-test-compose-subview` pass. The following example reproduces it:

#map = affine_map<() -> ()>
module {
  func.func private @fun(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %subview = memref.subview %arg0[0, 0] [5, 5] [1, 1] : memref<10x10xf32> to memref<5x5xf32, strided<[10, 1]>>
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[10, 1]>> to memref<f32, strided<[], offset: ?>>
        %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_2 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
        ^bb0(%in: f32, %in_4: f32, %out: f32):
          %0 = arith.addf %in, %in_4 : f32
          linalg.yield %0 : f32
        }
        %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
  func.func @test(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %0 = call @fun(%arg0, %arg1) : (memref<10x10xf32>, memref<5x5xf32>) -> memref<5x5xf32>
    return %0 : memref<5x5xf32>
  }
}

When I run `mlir-opt test.mlir -test-compose-subview`, I get the following error:

test.mlir:14:9: error: 'linalg.generic' op expected operand rank (2) to match the result rank of indexing_map #0 (0)
        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
        ^
test1.mlir:14:9: note: see current operation: 
"linalg.generic"(%4, %5, %6) <{indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = [], operandSegmentSizes = array<i32: 2, 1>}> ({
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
  %8 = "arith.addf"(%arg4, %arg5) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  "linalg.yield"(%8) : (f32) -> ()
}) : (memref<1x1xf32, strided<[10, 1], offset: ?>>, memref<f32, strided<[], offset: ?>>, memref<f32>) -> ()

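This PR fixes that by creating the composed subview with the original op's result type, so a rank-reducing subview keeps its rank-reduced type. I've also extended the pattern to handle cases where the stride is greater than 1, for example: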

func.func private @Unknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %c1 = arith.constant 1 : index
  %subview = memref.subview %arg0[0, 0] [5, 5] [2, 2] : memref<10x10xf32> to memref<5x5xf32, strided<[20, 2]>>
  %alloc = memref.alloc() : memref<5x5xf32>
  scf.for %arg2 = %c0 to %c5 step %c1 {
    scf.for %arg3 = %c0 to %c5 step %c1 {
      %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[20, 2]>> to memref<f32, strided<[], offset: ?>>
      %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      %alloc_2 = memref.alloc() : memref<f32>
      linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
      ^bb0(%in: f32, %in_4: f32, %out: f32):
        %0 = arith.addf %in, %in_4 : f32
        linalg.yield %0 : f32
      }
      %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
    }
  }
  return %alloc : memref<5x5xf32>
}
$ mlir-opt test.mlir -test-compose-subview
#map = affine_map<()[s0] -> (s0 * 2)>
#map1 = affine_map<() -> ()>
module {
  func.func private @Unknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32>  {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %0 = affine.apply #map()[%arg2]
        %1 = affine.apply #map()[%arg3]
        %subview = memref.subview %arg0[%0, %1] [1, 1] [2, 2] : memref<10x10xf32> to memref<f32, strided<[], offset: ?>>
        %subview_0 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_1 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = []} ins(%subview, %subview_0 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_1 : memref<f32>) {
        ^bb0(%in: f32, %in_3: f32, %out: f32):
          %2 = arith.addf %in, %in_3 : f32
          linalg.yield %2 : f32
        }
        %subview_2 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_1, %subview_2 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
}

@llvmbot
Copy link
Member

llvmbot commented Feb 3, 2024

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: lonely eagle (linuxlonelyeagle)

Changes

I found a bug in the `-test-compose-subview` pass. The following example reproduces it:

#map = affine_map<() -> ()>
module {
  func.func private @fun(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %subview = memref.subview %arg0[0, 0] [5, 5] [1, 1] : memref<10x10xf32> to memref<5x5xf32, strided<[10, 1]>>
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[10, 1]>> to memref<f32, strided<[], offset: ?>>
        %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_2 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
        ^bb0(%in: f32, %in_4: f32, %out: f32):
          %0 = arith.addf %in, %in_4 : f32
          linalg.yield %0 : f32
        }
        %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
  func.func @test(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
    %0 = call @fun(%arg0, %arg1) : (memref<10x10xf32>, memref<5x5xf32>) -> memref<5x5xf32>
    return %0 : memref<5x5xf32>
  }
}

When I run `mlir-opt test.mlir -test-compose-subview`, I get the following error:

test.mlir:14:9: error: 'linalg.generic' op expected operand rank (2) to match the result rank of indexing_map #0 (0)
        linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
        ^
test1.mlir:14:9: note: see current operation: 
"linalg.generic"(%4, %5, %6) <{indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = [], operandSegmentSizes = array<i32: 2, 1>}> ({
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
  %8 = "arith.addf"(%arg4, %arg5) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
  "linalg.yield"(%8) : (f32) -> ()
}) : (memref<1x1xf32, strided<[10, 1], offset: ?>>, memref<f32, strided<[], offset: ?>>, memref<f32>) -> ()

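This PR fixes that by creating the composed subview with the original op's result type, so a rank-reducing subview keeps its rank-reduced type. I've also extended the pattern to handle cases where the stride is greater than 1, for example: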

func.func private @Unknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32> {
  %c0 = arith.constant 0 : index
  %c5 = arith.constant 5 : index
  %c1 = arith.constant 1 : index
  %subview = memref.subview %arg0[0, 0] [5, 5] [2, 2] : memref<10x10xf32> to memref<5x5xf32, strided<[20, 2]>>
  %alloc = memref.alloc() : memref<5x5xf32>
  scf.for %arg2 = %c0 to %c5 step %c1 {
    scf.for %arg3 = %c0 to %c5 step %c1 {
      %subview_0 = memref.subview %subview[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32, strided<[20, 2]>> to memref<f32, strided<[], offset: ?>>
      %subview_1 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      %alloc_2 = memref.alloc() : memref<f32>
      linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%subview_0, %subview_1 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_2 : memref<f32>) {
      ^bb0(%in: f32, %in_4: f32, %out: f32):
        %0 = arith.addf %in, %in_4 : f32
        linalg.yield %0 : f32
      }
      %subview_3 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
      memref.copy %alloc_2, %subview_3 : memref<f32> to memref<f32, strided<[], offset: ?>>
    }
  }
  return %alloc : memref<5x5xf32>
}
$ mlir-opt test.mlir -test-compose-subview
#map = affine_map<()[s0] -> (s0 * 2)>
#map1 = affine_map<() -> ()>
module {
  func.func private @Unknown0(%arg0: memref<10x10xf32>, %arg1: memref<5x5xf32>) -> memref<5x5xf32>  {
    %c0 = arith.constant 0 : index
    %c5 = arith.constant 5 : index
    %c1 = arith.constant 1 : index
    %alloc = memref.alloc() : memref<5x5xf32>
    scf.for %arg2 = %c0 to %c5 step %c1 {
      scf.for %arg3 = %c0 to %c5 step %c1 {
        %0 = affine.apply #map()[%arg2]
        %1 = affine.apply #map()[%arg3]
        %subview = memref.subview %arg0[%0, %1] [1, 1] [2, 2] : memref<10x10xf32> to memref<f32, strided<[], offset: ?>>
        %subview_0 = memref.subview %arg1[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        %alloc_1 = memref.alloc() : memref<f32>
        linalg.generic {indexing_maps = [#map1, #map1, #map1], iterator_types = []} ins(%subview, %subview_0 : memref<f32, strided<[], offset: ?>>, memref<f32, strided<[], offset: ?>>) outs(%alloc_1 : memref<f32>) {
        ^bb0(%in: f32, %in_3: f32, %out: f32):
          %2 = arith.addf %in, %in_3 : f32
          linalg.yield %2 : f32
        }
        %subview_2 = memref.subview %alloc[%arg2, %arg3] [1, 1] [1, 1] : memref<5x5xf32> to memref<f32, strided<[], offset: ?>>
        memref.copy %alloc_1, %subview_2 : memref<f32> to memref<f32, strided<[], offset: ?>>
      }
    }
    return %alloc : memref<5x5xf32>
  }
}

Full diff: https://github.com/llvm/llvm-project/pull/80551.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp (+52-37)
  • (modified) mlir/test/Transforms/compose-subview.mlir (+48)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 431d270b0a2cb..dd10865544a50 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -24,14 +24,14 @@ using namespace mlir;
 
 namespace {
 
-// Replaces a subview of a subview with a single subview. Only supports subview
-// ops with static sizes and static strides of 1 (both static and dynamic
+// Replaces a subview of a subview with a single subview(both static and dynamic
 // offsets are supported).
 struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(memref::SubViewOp op,
                                 PatternRewriter &rewriter) const override {
+
     // 'op' is the 'SubViewOp' we're rewriting. 'sourceOp' is the op that
     // produces the input of the op we're rewriting (for 'SubViewOp' the input
     // is called the "source" value). We can only combine them if both 'op' and
@@ -52,66 +52,81 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
 
     // Offsets, sizes and strides OpFoldResult for the combined 'SubViewOp'.
     SmallVector<OpFoldResult> offsets, sizes, strides;
-
-    // Because we only support input strides of 1, the output stride is also
-    // always 1.
-    if (llvm::all_of(strides, [](OpFoldResult &valueOrAttr) {
-          Attribute attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr);
-          return attr && cast<IntegerAttr>(attr).getInt() == 1;
-        })) {
-      strides = SmallVector<OpFoldResult>(sourceOp.getMixedStrides().size(),
-                                          rewriter.getI64IntegerAttr(1));
-    } else {
-      return failure();
+    auto opStrides = op.getMixedStrides();
+    auto sourceStrides = sourceOp.getMixedStrides();
+
+    // The output stride in each dimension is equal to the product of the
+    // dimensions corresponding to source and op.
+    for (auto [opStride, sourceStride] : llvm::zip(opStrides, sourceStrides)) {
+      Attribute opStrideAttr = dyn_cast_if_present<Attribute>(opStride);
+      Attribute sourceStrideAttr = dyn_cast_if_present<Attribute>(sourceStride);
+      if (!opStrideAttr || !sourceStrideAttr)
+        return failure();
+      strides.push_back(rewriter.getI64IntegerAttr(
+          cast<IntegerAttr>(opStrideAttr).getInt() *
+          cast<IntegerAttr>(sourceStrideAttr).getInt()));
     }
 
     // The rules for calculating the new offsets and sizes are:
     // * Multiple subview offsets for a given dimension compose additively.
-    //   ("Offset by m" followed by "Offset by n" == "Offset by m + n")
+    //   ("Offset by m and Stride by k" followed by "Offset by n" == "Offset by
+    //   m + n * k")
     // * Multiple sizes for a given dimension compose by taking the size of the
     //   final subview and ignoring the rest. ("Take m values" followed by "Take
     //   n values" == "Take n values") This size must also be the smallest one
     //   by definition (a subview needs to be the same size as or smaller than
     //   its source along each dimension; presumably subviews that are larger
     //   than their sources are disallowed by validation).
-    for (auto it : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
-                             op.getMixedSizes())) {
-      auto opOffset = std::get<0>(it);
-      auto sourceOffset = std::get<1>(it);
-      auto opSize = std::get<2>(it);
-
+    for (auto [opOffset, sourceOffset, sourceStride, opSize] :
+         llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
+                   sourceOp.getMixedStrides(), op.getMixedSizes())) {
       // We only support static sizes.
       if (opSize.is<Value>()) {
         return failure();
       }
-
       sizes.push_back(opSize);
       Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
                 sourceOffsetAttr =
-                    llvm::dyn_cast_if_present<Attribute>(sourceOffset);
-
+                    llvm::dyn_cast_if_present<Attribute>(sourceOffset),
+                sourceStrideAttr =
+                    llvm::dyn_cast_if_present<Attribute>(sourceStride);
       if (opOffsetAttr && sourceOffsetAttr) {
+
         // If both offsets are static we can simply calculate the combined
         // offset statically.
         offsets.push_back(rewriter.getI64IntegerAttr(
-            cast<IntegerAttr>(opOffsetAttr).getInt() +
+            cast<IntegerAttr>(opOffsetAttr).getInt() *
+                cast<IntegerAttr>(sourceStrideAttr).getInt() +
             cast<IntegerAttr>(sourceOffsetAttr).getInt()));
       } else {
-        // When either offset is dynamic, we must emit an additional affine
-        // transformation to add the two offsets together dynamically.
-        AffineExpr expr = rewriter.getAffineConstantExpr(0);
+        AffineExpr expr0 = rewriter.getAffineConstantExpr(0);
+        AffineExpr expr1 = rewriter.getAffineConstantExpr(0);
         SmallVector<Value> affineApplyOperands;
-        for (auto valueOrAttr : {opOffset, sourceOffset}) {
-          if (auto attr = llvm::dyn_cast_if_present<Attribute>(valueOrAttr)) {
-            expr = expr + cast<IntegerAttr>(attr).getInt();
+        SmallVector<OpFoldResult> opOffsets{sourceOffset, opOffset};
+        for (auto [idx, offset] : llvm::enumerate(opOffsets)) {
+          if (auto attr = llvm::dyn_cast_if_present<Attribute>(offset)) {
+            if (idx == 0) {
+              expr0 = expr0 + cast<IntegerAttr>(attr).getInt();
+            } else if (idx == 1) {
+              expr1 = expr1 + cast<IntegerAttr>(attr).getInt();
+              expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
+              expr0 = expr0 + expr1;
+            }
           } else {
-            expr =
-                expr + rewriter.getAffineSymbolExpr(affineApplyOperands.size());
-            affineApplyOperands.push_back(valueOrAttr.get<Value>());
+            if (idx == 0) {
+              expr0 = expr0 +
+                      rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+              affineApplyOperands.push_back(offset.get<Value>());
+            } else if (idx == 1) {
+              expr1 = expr1 +
+                      rewriter.getAffineSymbolExpr(affineApplyOperands.size());
+              affineApplyOperands.push_back(offset.get<Value>());
+              expr1 = expr1 * cast<IntegerAttr>(sourceStrideAttr).getInt();
+              expr0 = expr0 + expr1;
+            }
           }
         }
-
-        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr);
+        AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr0);
         Value result = rewriter.create<affine::AffineApplyOp>(
             op.getLoc(), map, affineApplyOperands);
         offsets.push_back(result);
@@ -120,8 +135,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
 
     // This replaces 'op' but leaves 'sourceOp' alone; if it no longer has any
     // uses it can be removed by a (separate) dead code elimination pass.
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(op, sourceOp.getSource(),
-                                                   offsets, sizes, strides);
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        op, op.getType(), sourceOp.getSource(), offsets, sizes, strides);
     return success();
   }
 };
diff --git a/mlir/test/Transforms/compose-subview.mlir b/mlir/test/Transforms/compose-subview.mlir
index cb8ebcb2bf9e9..2fa551c218d42 100644
--- a/mlir/test/Transforms/compose-subview.mlir
+++ b/mlir/test/Transforms/compose-subview.mlir
@@ -45,3 +45,51 @@ func.func @main(%input: memref<4x1024xf32>) -> memref<1x128xf32, strided<[1024,
   %1 = memref.subview %0[1, %cst_128] [1, 128] [1, 1] : memref<2x256xf32, strided<[1024, 1], offset: ?>> to memref<1x128xf32, strided<[1024, 1], offset: ?>>
   return %1 : memref<1x128xf32, strided<[1024, 1], offset: ?>>
 }
+
+// -----
+
+func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: 4480>> {
+  //      CHECK: subview %arg0[4, 384] [1, 64] [4, 4] 
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+  %0 = memref.subview %input[2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: 2304>>
+  %1 = memref.subview %0[1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: 2304>> to memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: 4480>>
+}
+
+// -----
+
+func.func @main(%input: memref<30x30xf32>) -> memref<2x2xf32, strided<[240, 8], offset: 217>> {
+  //      CHECK: subview %arg0[7, 7] [2, 2] [8, 8]
+  // CHECK-SAME: memref<30x30xf32> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+  %0 = memref.subview %input[1, 1] [12, 12] [2, 2] : memref<30x30xf32> to memref<12x12xf32, strided<[60, 2], offset: 31>>
+  %1 = memref.subview %0[1, 1] [5, 5] [2, 2] : memref<12x12xf32, strided<[60, 2], offset: 31>> to memref<5x5xf32, strided<[120, 4], offset: 93>>
+  %2 = memref.subview %1[1, 1] [2, 2] [2, 2] : memref<5x5xf32, strided<[120, 4], offset: 93>> to memref<2x2xf32, strided<[240, 8], offset: 217>>
+  return %2 : memref<2x2xf32, strided<[240, 8], offset: 217>> 
+}
+
+// -----
+
+func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  //      CHECK:%[[VAL_1:.*]] = arith.constant 4 : index
+  %cst_2 = arith.constant 2 : index
+  //      CHECK:%[[VAL_2:.*]] = arith.constant 384 : index
+  %cst_64 = arith.constant 64 : index
+  //      CHECK: subview %arg0{{\[}}%[[VAL_1]], %[[VAL_2]]] [1, 64] [4, 4]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+  %1 = memref.subview %0[1, %cst_64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}
+
+// -----
+
+func.func @main(%input: memref<4x1024xf32>) -> memref<1x64xf32, strided<[4096, 4], offset: ?>> {
+  //      CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+  %cst_1 = arith.constant 1 : index
+  %cst_2 = arith.constant 2 : index
+  //      CHECK: subview %arg0{{\[}}%[[VAL_1]], 384] [1, 64] [4, 4]
+  // CHECK-SAME: memref<4x1024xf32> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  %0 = memref.subview %input[%cst_2, 256] [2, 256] [2, 2] : memref<4x1024xf32> to memref<2x256xf32, strided<[2048, 2], offset: ?>>
+  %1 = memref.subview %0[%cst_1, 64] [1, 64] [2, 2] : memref<2x256xf32, strided<[2048, 2], offset: ?>> to memref<1x64xf32, strided<[4096, 4], offset: ?>>
+  return %1 : memref<1x64xf32, strided<[4096, 4], offset: ?>>
+}

Copy link
Contributor

@bondhugula bondhugula left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments.

@ftynse ftynse merged commit 2ecf608 into llvm:main Feb 7, 2024
@linuxlonelyeagle
Copy link
Member Author

@ftynse @bondhugula Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants