@@ -2403,6 +2403,53 @@ func.func @dim_of_reshape_undominated(%arg0: tensor<*xf32>, %arg1: tensor<?xinde
2403
2403
2404
2404
// -----
2405
2405
2406
+ // CHECK-LABEL: @reshape_fold_2d
2407
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2408
+ func.func @reshape_fold_2d (%arg0 : tensor <?x?xi32 >) -> tensor <?x?xi32 > {
2409
+ %c0 = arith.constant 0 : index
2410
+ %c1 = arith.constant 1 : index
2411
+ %d0 = tensor.dim %arg0 , %c0 : tensor <?x?xi32 >
2412
+ %d1 = tensor.dim %arg0 , %c1 : tensor <?x?xi32 >
2413
+ %ds = tensor.from_elements %d0 , %d1 : tensor <2 xindex >
2414
+ %reshape = tensor.reshape %arg0 (%ds ) : (tensor <?x?xi32 >, tensor <2 xindex >) -> tensor <?x?xi32 >
2415
+ // CHECK: return %[[ARG0]]
2416
+ return %reshape : tensor <?x?xi32 >
2417
+ }
2418
+
2419
+ // -----
2420
+
2421
+ // CHECK-LABEL: @reshape_nofold_2d
2422
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>
2423
+ func.func @reshape_nofold_2d (%arg0 : tensor <?x?xi32 >) -> tensor <?x?xi32 > {
2424
+ %c0 = arith.constant 0 : index
2425
+ %c1 = arith.constant 1 : index
2426
+ %d0 = tensor.dim %arg0 , %c0 : tensor <?x?xi32 >
2427
+ %d1 = tensor.dim %arg0 , %c1 : tensor <?x?xi32 >
2428
+ %ds = tensor.from_elements %d1 , %d0 : tensor <2 xindex >
2429
+ // CHECK: tensor.reshape
2430
+ %reshape = tensor.reshape %arg0 (%ds ) : (tensor <?x?xi32 >, tensor <2 xindex >) -> tensor <?x?xi32 >
2431
+ return %reshape : tensor <?x?xi32 >
2432
+ }
2433
+
2434
+
2435
+ // -----
2436
+
2437
+ // CHECK-LABEL: @reshape_fold_3d_cst
2438
+ // CHECK-SAME: %[[ARG0:.+]]: tensor<5x?x?xi32>
2439
+ func.func @reshape_fold_3d_cst (%arg0 : tensor <5 x?x?xi32 >) -> tensor <5 x?x?xi32 > {
2440
+ %c1 = arith.constant 1 : index
2441
+ %c2 = arith.constant 2 : index
2442
+ %d0 = arith.constant 5 : index
2443
+ %d1 = tensor.dim %arg0 , %c1 : tensor <5 x?x?xi32 >
2444
+ %d2 = tensor.dim %arg0 , %c2 : tensor <5 x?x?xi32 >
2445
+ %ds = tensor.from_elements %d0 , %d1 , %d2 : tensor <3 xindex >
2446
+ %reshape = tensor.reshape %arg0 (%ds ) : (tensor <5 x?x?xi32 >, tensor <3 xindex >) -> tensor <5 x?x?xi32 >
2447
+ // CHECK: return %[[ARG0]]
2448
+ return %reshape : tensor <5 x?x?xi32 >
2449
+ }
2450
+
2451
+ // -----
2452
+
2406
2453
// Test case: This test fails to fold because the index of tensor.dim is out_of_bounds
2407
2454
// CHECK-LABEL: func @dim_out_of_bounds(
2408
2455
// CHECK: %[[IDX:.*]] = index.constant 28
0 commit comments