Skip to content

Commit 593f6fd

Browse files
authored
[mlir][tensor] Fix tensor.reshape canonicalization (llvm#90141)
Canonicalization defaulted to replacement when the input dims were from unknown source. This is obviously incorrect. Tweaked and included test to prevent future issue.
1 parent c2170a3 commit 593f6fd

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1609,6 +1609,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
16091609
cst.has_value() && cst.value() == static_cast<int64_t>(id);
16101610
continue;
16111611
}
1612+
1613+
dynamicNoop = false;
1614+
break;
16121615
}
16131616

16141617
if (dynamicNoop)

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2431,6 +2431,15 @@ func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
24312431
return %reshape : tensor<?x?xi32>
24322432
}
24332433

2434+
// -----
2435+
2436+
// CHECK-LABEL: @reshape_nofold_2d_ins
2437+
func.func @reshape_nofold_2d_ins(%arg0 : tensor<?x?xi32>, %arg1: index, %arg2: index) -> tensor<?x?xi32> {
2438+
%ds = tensor.from_elements %arg1, %arg2 : tensor<2xindex>
2439+
// CHECK: tensor.reshape
2440+
%reshape = tensor.reshape %arg0(%ds) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
2441+
return %reshape : tensor<?x?xi32>
2442+
}
24342443

24352444
// -----
24362445

0 commit comments

Comments
 (0)