// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]

// -----

// Regression test: an expand_shape whose source is consumed through a
// collapse_shape feeding a linalg.generic input must NOT be fused into the
// generic; instead the expand is materialized and the result is collapsed back
// and reshaped to the original 4-D shape.
#map = affine_map<()[s0, s1] -> (s0 * s1)>
#map1 = affine_map<(d0, d1, d2) -> (d0)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @no_fuse_expand_collapsed_generic_input(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>)
func.func @no_fuse_expand_collapsed_generic_input(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
  // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.+}} {{\[\[}}0, 1], [2], [3]] output_shape {{\[}}%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  // CHECK: %[[OUT:.*]] = tensor.empty(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) : tensor<?x?x?x?xf32>
  // CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_3]] : tensor<?x?xi64>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) {
  // CHECK: ^bb0(%[[VAL_5:.*]]: i64, %[[VAL_6:.*]]: f32):
  // CHECK: %[[OFFSETS:.*]] = arith.index_cast %[[VAL_5]] : i64 to index
  // CHECK: %[[SIZES:.*]] = linalg.index 2 : index
  // CHECK: %[[STRIDES:.*]] = linalg.index 3 : index
  // CHECK: %[[EXTRACT:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[OFFSETS]], %[[SIZES]], %[[STRIDES]]] : tensor<?x?x?xf32>
  // CHECK: linalg.yield %[[EXTRACT]] : f32
  // CHECK: } -> tensor<?x?x?x?xf32>
  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAL_4]] {{\[\[}}0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  // CHECK: %[[SHAPE:.*]] = tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.reshape %[[COLLAPSED]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x?x?x?xf32>
  // CHECK: }
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  // Dynamic extents of the leading collapsed dims, kept as i64 for the final reshape.
  %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
  %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
  %0 = arith.index_cast %dim : index to i64
  %1 = arith.index_cast %dim_0 : index to i64
  // Collapse the 2-D index tensor to 1-D so it feeds the 3-D generic's first dim.
  %collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
  %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  // Combined size of the two collapsed dims: s0 * s1.
  %2 = affine.apply #map()[%dim, %dim_0]
  %3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
  // Gather rows of %arg0 selected by the collapsed i64 indices.
  %4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
  ^bb0(%in: i64, %out: f32):
    %7 = arith.index_cast %in : i64 to index
    %8 = linalg.index 1 : index
    %9 = linalg.index 2 : index
    %extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x?x?xf32>
  %5 = arith.index_cast %dim_1 : index to i64
  %6 = arith.index_cast %dim_2 : index to i64
  // Restore the original 4-D shape via an explicit runtime shape tensor.
  %from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
  %reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
  return %reshape : tensor<?x?x?x?xf32>
}