Skip to content

[mlir][tensor] Fix bufferization interface for 'tensor.reshape' #128590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

christopherbate
Copy link
Contributor

Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
listed the 'shape' operand as an alias for the result tensor, causing
unnecessary conflicts with ops that "write" to the shape operand.

Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
listed the 'shape' operand as an alias for the result tensor, causing
unnecessary conflicts with ops that "write" to the shape operand.
@llvmbot
Copy link
Member

llvmbot commented Feb 24, 2025

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Christopher Bate (christopherbate)

Changes

Previously, the BufferizableOpInterface implementation for 'tensor.reshape'
listed the 'shape' operand as an alias for the result tensor, causing
unnecessary conflicts with ops that "write" to the shape operand.


Full diff: https://github.com/llvm/llvm-project/pull/128590.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+4)
  • (modified) mlir/test/Dialect/Tensor/one-shot-bufferize.mlir (+27)
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 81404fa664cd4..8b7aee67ea5c2 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -862,6 +862,10 @@ struct ReshapeOpInterface
 
   AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
                                       const AnalysisState &state) const {
+    // Only the 'source' operand aliases the result.
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    if (reshapeOp.getSourceMutable() != opOperand)
+      return {};
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }
 
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
index af4f84640890b..2983cd30258a5 100644
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -398,6 +398,33 @@ func.func @tensor.reshape() -> tensor<2x2x5xf32> {
 
 // -----
 
+// CHECK-LABEL: func @tensor_reshape_aliasing
+//  CHECK-SAME:  (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index)
+func.func @tensor_reshape_aliasing(%arg0: index, %arg1: index) -> tensor<?x?xf32> {
+  %t1_static = arith.constant dense<0.> : tensor<10xf32>
+  // CHECK-DAG: %[[T1:.+]] = memref.cast
+  %t1 = tensor.cast %t1_static : tensor<10xf32> to tensor<?xf32>
+
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  %c1 = arith.constant 1 : index
+
+  // CHECK-DAG: %[[SHAPE:.+]] = memref.alloc() {{.*}} : memref<2xindex>
+  %shape = bufferization.alloc_tensor() : tensor<2xindex>
+  // CHECK: memref.store %[[ARG0]], %[[SHAPE]][%[[C0]]]
+  %shape.0 = tensor.insert %arg0 into %shape[%c0] : tensor<2xindex>
+  // CHECK: memref.store %[[ARG1]], %[[SHAPE]][%[[C1]]]
+  %shape.1 = tensor.insert %arg1 into %shape.0[%c1] : tensor<2xindex>
+
+  // CHECK: %[[RESHAPED:.+]] = memref.reshape %[[T1]](%[[SHAPE]])
+  %reshaped = tensor.reshape %t1(%shape.1) : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
+  // CHECK: return %[[RESHAPED]]
+  return %reshaped : tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @reshape_with_non_identity_layout(
 // CHECK-SAME:    %[[INPUT:[a-zA-Z0-9]*]]: memref<2x2xf32, strided<[?, ?], offset: ?>, 3>,
 // CHECK-SAME:    %[[LAYOUT:[a-zA-Z0-9]*]]: memref<2xi32, strided<[?], offset: ?>>,

@christopherbate christopherbate merged commit 3438dfc into llvm:main Mar 13, 2025
11 checks passed
frederik-h pushed a commit to frederik-h/llvm-project that referenced this pull request Mar 18, 2025
…#128590)

Previously, the BufferizableOpInterface implementation for
'tensor.reshape'
listed the 'shape' operand as an alias for the result tensor, causing
unnecessary conflicts with ops that "write" to the shape operand.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants