Skip to content

Commit c045955

Browse files
authored
[mlir][tensor] Fold tensor.reshape for dynamic reshape (#88961)
If a `tensor.reshape`'s target shape is `d0, d1, d2, ...` — the source tensor's own dimensions, in order — then the reshape is a no-op. Checking for this case lets us fold the operation away entirely.
1 parent ab22504 commit c045955

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,41 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
15801580
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
15811581
getResult().getType()))
15821582
return reshapedSource;
1583+
1584+
auto source = getSource();
1585+
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
1586+
auto resultTy = dyn_cast<RankedTensorType>(getType());
1587+
1588+
if (!sourceTy || !resultTy || sourceTy != resultTy)
1589+
return {};
1590+
1591+
if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
1592+
auto elements = fromElements.getElements();
1593+
bool dynamicNoop =
1594+
sourceTy.getRank() == static_cast<int64_t>(elements.size());
1595+
for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
1596+
auto element = elements[id];
1597+
1598+
if (auto cst = getConstantIntValue(element)) {
1599+
dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
1600+
continue;
1601+
}
1602+
1603+
if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
1604+
dynamicNoop &= dimOp.getSource() == source;
1605+
1606+
APSInt dim;
1607+
auto cst = getConstantIntValue(dimOp.getIndex());
1608+
dynamicNoop &=
1609+
cst.has_value() && cst.value() == static_cast<int64_t>(id);
1610+
continue;
1611+
}
1612+
}
1613+
1614+
if (dynamicNoop)
1615+
return source;
1616+
}
1617+
15831618
return {};
15841619
}
15851620

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2403,6 +2403,53 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
24032403

24042404
// -----
24052405

2406+
// Positive case: the target shape is built from this tensor's own dims in
// order (dim 0, dim 1), so the reshape is a no-op and folds to the argument.
// CHECK-LABEL: @reshape_fold_2d
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
func.func @reshape_fold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg0, %idx0 : tensor<?x?xi32>
  %dim1 = tensor.dim %arg0, %idx1 : tensor<?x?xi32>
  %shape = tensor.from_elements %dim0, %dim1 : tensor<2xindex>
  %result = tensor.reshape %arg0(%shape) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  // CHECK: return %[[ARG0]]
  return %result : tensor<?x?xi32>
}
2418+
2419+
// -----
2420+
2421+
// Negative case: the target shape uses the source dims in swapped order
// (dim 1, dim 0), so this is a genuine reshape and must NOT fold away.
// CHECK-LABEL: @reshape_nofold_2d
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
func.func @reshape_nofold_2d(%arg0 : tensor<?x?xi32>) -> tensor<?x?xi32> {
  %idx0 = arith.constant 0 : index
  %idx1 = arith.constant 1 : index
  %dim0 = tensor.dim %arg0, %idx0 : tensor<?x?xi32>
  %dim1 = tensor.dim %arg0, %idx1 : tensor<?x?xi32>
  %shape = tensor.from_elements %dim1, %dim0 : tensor<2xindex>
  // CHECK: tensor.reshape
  %result = tensor.reshape %arg0(%shape) : (tensor<?x?xi32>, tensor<2xindex>) -> tensor<?x?xi32>
  return %result : tensor<?x?xi32>
}
2433+
2434+
2435+
// -----
2436+
2437+
// Positive case with a mixed static/dynamic source: dim 0 is supplied as the
// constant 5, matching the static extent in tensor<5x?x?xi32>, and dims 1-2
// come from tensor.dim on the source — still an identity reshape, so it folds.
// CHECK-LABEL: @reshape_fold_3d_cst
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
func.func @reshape_fold_3d_cst(%arg0 : tensor<5x?x?xi32>) -> tensor<5x?x?xi32> {
  %idx1 = arith.constant 1 : index
  %idx2 = arith.constant 2 : index
  %five = arith.constant 5 : index
  %dim1 = tensor.dim %arg0, %idx1 : tensor<5x?x?xi32>
  %dim2 = tensor.dim %arg0, %idx2 : tensor<5x?x?xi32>
  %shape = tensor.from_elements %five, %dim1, %dim2 : tensor<3xindex>
  %result = tensor.reshape %arg0(%shape) : (tensor<5x?x?xi32>, tensor<3xindex>) -> tensor<5x?x?xi32>
  // CHECK: return %[[ARG0]]
  return %result : tensor<5x?x?xi32>
}
2450+
2451+
// -----
2452+
24062453
// Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
24072454
// CHECK-LABEL: func @dim_out_of_bounds(
24082455
// CHECK: %[[IDX:.*]] = index.constant 28

0 commit comments

Comments
 (0)