@@ -61,11 +61,37 @@ module attributes {transform.with_named_sequence} {
61
61
62
62
// -----
63
63
64
- // CHECK-LABEL: func.func @pack_all_dyn(
65
- func.func @pack_all_dyn (%arg0: tensor <?x?xf32 >, %arg1: tensor <?x?x?x?xf32 >) -> tensor <?x?x?x?xf32 > {
66
- %pack = tensor.pack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [16 , 2 ] into %arg1
67
- : tensor <?x?xf32 > -> tensor <?x?x?x?xf32 >
68
-
// Test case @pack_dyn_tiles: a tensor.pack of a static 64x128 f32 source into
// a fully dynamic 4-D destination, where both inner tile sizes are supplied as
// runtime `index` function arguments (%tile_0, %tile_1) rather than constants.
//
// The FileCheck directives below expect, in order:
//   1. two affine maps giving the high-padding amounts as
//      (destination dim * tile size) - 64 and (destination dim * tile size) - 128;
//   2. a tensor.pad of the source with zero low padding, those high paddings,
//      and a region yielding the f32 constant 0.0;
//   3. a 4-element index tensor holding the reshape target shape;
//   4. a tensor.reshape of the padded value to that shape;
//   5. a linalg.transpose with permutation [0, 2, 1, 3] writing into the
//      destination argument.
//
// NOTE(review): the pass / transform sequence that produces this output is not
// visible in this hunk; confirm against the RUN line and named sequence of the
// full test file.
64
+ // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 64)>
65
+ // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s1 * s0 - 128)>
66
+ // CHECK: func.func @pack_dyn_tiles(
67
+ // CHECK-SAME: %[[ARG0:.*]]: [[TENSOR_TY_0:tensor<64x128xf32>]]
68
+ // CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>,
69
+ // CHECK-SAME: %[[TILE0:.*]]: index,
70
+ // CHECK-SAME: %[[TILE1:.*]]: index
71
+ func.func @pack_dyn_tiles (%arg0: tensor <64 x128 xf32 >, %arg1: tensor <?x?x?x?xf32 >, %tile_0: index , %tile_1: index ) -> tensor <?x?x?x?xf32 > {
// Padding amounts: each affine.apply is fed [tile size, destination dim], so
// s0 is the tile size and s1 is the matching outer destination dimension.
72
+ // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
73
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
74
+ // CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
75
+ // CHECK-DAG: %[[PAD0:.*]] = affine.apply #[[MAP0]]()[%[[TILE0]], %[[DIM0]]]
76
+ // CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
77
+ // CHECK-DAG: %[[PAD1:.*]] = affine.apply #[[MAP1]]()[%[[TILE1]], %[[DIM1]]]
78
+ // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
79
+ // CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0] high[%[[PAD0]], %[[PAD1]]]
80
+ // CHECK-NEXT: ^bb0
81
+ // CHECK-NEXT: tensor.yield %[[CST]] : f32
// Reshape target shape: destination dims are inserted in the order 0, 2, 1, 3,
// which the transpose permutation [0, 2, 1, 3] further down then reorders.
82
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
83
+ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
84
+ // CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
85
+ // CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
86
+ // CHECK-NEXT: %[[INIT_SHAPE:.*]] = tensor.empty() : tensor<4xindex>
87
+ // CHECK-NEXT: %[[SHAPE0:.*]] = tensor.insert %[[DIM0]] into %[[INIT_SHAPE]][%[[C0]]]
88
+ // CHECK-NEXT: %[[SHAPE1:.*]] = tensor.insert %[[DIM2]] into %[[SHAPE0]][%[[C1]]]
89
+ // CHECK-NEXT: %[[SHAPE2:.*]] = tensor.insert %[[DIM1]] into %[[SHAPE1]][%[[C2]]]
90
+ // CHECK-NEXT: %[[SHAPE3:.*]] = tensor.insert %[[DIM3]] into %[[SHAPE2]][%[[C3]]]
91
+ // CHECK-NEXT: %[[EXPANDED:.*]] = tensor.reshape %[[PADDED]](%[[SHAPE3]])
92
+ // CHECK-NEXT: %[[TRANSPOSED:.*]] = linalg.transpose ins(%[[EXPANDED]] : {{.*}}) outs(%[[ARG1]] {{.*}}) permutation = [0, 2, 1, 3]
// Input IR under test: the pack with dynamic inner tiles on dims 0 and 1.
93
+ %pack = tensor.pack %arg0 inner_dims_pos = [0 , 1 ] inner_tiles = [%tile_0 , %tile_1 ] into %arg1
94
+ : tensor <64 x128 xf32 > -> tensor <?x?x?x?xf32 >
69
95
return %pack : tensor <?x?x?x?xf32 >
70
96
}
71
97
0 commit comments