@@ -366,14 +366,10 @@ func.func @tensor.insert(%t1: tensor<5xf32>, %idx1: index, %f: f32) -> tensor<5x
// -----
// CHECK-LABEL: func @tensor.expand_shape(
- // CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+ // CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>, %[[sz0:.*]]: index
func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?x10xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
- // CHECK: %[[C0:.*]] = arith.constant 0 : index
- // CHECK: %[[DIM:.*]] = memref.dim %[[m1]], %[[C0]] : memref<?x10xf32>
- // CHECK: %[[C2:.*]] = arith.constant 2 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %[[DIM]], %[[C2]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[VAL_1]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1], [2]] output_shape [2, %[[sz0]], 10] : memref<?x10xf32> into memref<2x?x10xf32>
%0 = tensor.expand_shape %t1 [[0, 1], [2]] output_shape [2, %sz0, 10]
: tensor<?x10xf32> into tensor<2x?x10xf32>
@@ -385,23 +381,20 @@ func.func @tensor.expand_shape(%t1: tensor<?x10xf32>, %sz0: index) -> tensor<2x?
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_slice(
- // CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>
+ // CHECK-SAME: %[[t1:.*]]: tensor<?x20xf32>, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index
func.func @tensor.expand_shape_of_slice(
%t1: tensor<?x20xf32>, %o1: index, %s1: index, %sz0: index) -> tensor<?x7x2x5xf32> {
// CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] :
// CHECK: %[[subview:.*]] = memref.subview %[[m1]][%{{.*}}, 5] [%{{.*}}, 10] [1, 1] : memref<?x20xf32> to memref<?x10xf32, strided<[20, 1], offset: ?>>
%0 = tensor.extract_slice %t1[%o1, 5][%s1, 10][1, 1] :
tensor<?x20xf32> to tensor<?x10xf32>
- // CHECK: %[[C7:.*]] = arith.constant 7 : index
- // CHECK: %[[VAL_1:.*]] = arith.divsi %{{.*}}, %[[C7]] : index
- // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[VAL_1]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[subview]] {{\[\[}}0, 1], [2, 3]] output_shape [%[[sz0]], 7, 2, 5] : memref<?x10xf32, strided<[20, 1], offset: ?>> into memref<?x7x2x5xf32, strided<[140, 20, 5, 1], offset: ?>>
%1 = tensor.expand_shape %0 [[0, 1], [2, 3]] output_shape [%sz0, 7, 2, 5] :
tensor<?x10xf32> into tensor<?x7x2x5xf32>
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
// CHECK: return %[[r]]
return %1 : tensor<?x7x2x5xf32>
}
-
// -----
// CHECK-LABEL: func @tensor.expand_shape_of_scalar_slice(
@@ -417,7 +410,20 @@ func.func @tensor.expand_shape_of_scalar_slice(
// CHECK: return %[[r]]
return %1 : tensor<1xf32>
}
+ // -----
+ // CHECK-LABEL: func @tensor.expand_shape_multiple_dynamic_indices(
+ // CHECK-SAME: %[[t1:.*]]: tensor<?x256xf32>, %[[sz0:.*]]: index, %[[sz1:.*]]: index, %[[sz2:.*]]: index
+ func.func @tensor.expand_shape_multiple_dynamic_indices(%t1: tensor<?x256xf32>, %sz0: index, %sz1: index, %sz2: index) -> tensor<?x?x?x256xf32> {
+ // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]]
+ // CHECK: %[[expanded:.*]] = memref.expand_shape %[[m1]] {{\[\[}}0, 1, 2], [3]] output_shape [%[[sz0]], %[[sz1]], %[[sz2]], 256] : memref<?x256xf32> into memref<?x?x?x256xf32>
+ %0 = tensor.expand_shape %t1 [[0, 1, 2], [3]] output_shape [%sz0, %sz1, %sz2, 256]
+ : tensor<?x256xf32> into tensor<?x?x?x256xf32>
+
+ // CHECK: %[[r:.*]] = bufferization.to_tensor %[[expanded]]
+ // CHECK: return %[[r]]
+ return %0 : tensor<?x?x?x256xf32>
+ }
// -----
// CHECK-LABEL: func @tensor.collapse_shape(
@@ -646,3 +652,6 @@ func.func @parallel_insert_slice_copy_before_write(%in: tensor<4xf32>, %out: ten
// CHECK: }
return
}
+
+ // -----
+