@@ -977,3 +977,93 @@ module {
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
// CHECK: linalg.yield %[[T3]] : f32
// CHECK: return %[[GENERIC]]
+
+ // -----
+
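+ // A linalg.generic that gathers elements from %arg0 at indices taken from a
+ // collapsed i64 operand, with its result consumed by a tensor.reshape.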
+ #map = affine_map<()[s0, s1] -> (s0 * s1)>
+ #map1 = affine_map<(d0, d1, d2) -> (d0)>
+ #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ module {
+   func.func @no_fusion(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
+     %c1 = arith.constant 1 : index
+     %c0 = arith.constant 0 : index
+     %c2 = arith.constant 2 : index
+     %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
+     %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
+     %0 = arith.index_cast %dim : index to i64
+     %1 = arith.index_cast %dim_0 : index to i64
+     %collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
+     %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+     %dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+     %2 = affine.apply #map()[%dim, %dim_0]
+     %3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
+     %4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
+     ^bb0(%in: i64, %out: f32):
+       %7 = arith.index_cast %in : i64 to index
+       %8 = linalg.index 1 : index
+       %9 = linalg.index 2 : index
+       %extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
+       linalg.yield %extracted : f32
+     } -> tensor<?x?x?xf32>
+     %5 = arith.index_cast %dim_1 : index to i64
+     %6 = arith.index_cast %dim_2 : index to i64
+     %from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
+     %reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+     return %reshape : tensor<?x?x?x?xf32>
+   }
+ }
+
+ // -----
+
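+ // Same payload as above, with checks: the collapse_shape on the index operand
+ // folds away by expanding the generic op to 4-D, while the trailing
+ // tensor.reshape is not fused, so a collapse_shape/reshape pair remains on
+ // the result.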
+ #map = affine_map<()[s0, s1] -> (s0 * s1)>
+ #map1 = affine_map<(d0, d1, d2) -> (d0)>
+ #map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ // CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
+ // CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+ // CHECK-LABEL: func.func @no_fuse_expand_collapsed_generic_input(
+ // CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?x?xf32>,
+ // CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
+ // CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
+ // CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>)
+ func.func @no_fuse_expand_collapsed_generic_input(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
+ // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.+}} {{\[\[}}0, 1], [2], [3]] output_shape {{\[}}%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+ // CHECK: %[[OUT:.*]] = tensor.empty(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) : tensor<?x?x?x?xf32>
+ // CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_3]] : tensor<?x?xi64>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) {
+ // CHECK: ^bb0(%[[VAL_5:.*]]: i64, %[[VAL_6:.*]]: f32):
+ // CHECK: %[[OFFSETS:.*]] = arith.index_cast %[[VAL_5]] : i64 to index
+ // CHECK: %[[SIZES:.*]] = linalg.index 2 : index
+ // CHECK: %[[STRIDES:.*]] = linalg.index 3 : index
+ // CHECK: %[[EXTRACT:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[OFFSETS]], %[[SIZES]], %[[STRIDES]]] : tensor<?x?x?xf32>
+ // CHECK: linalg.yield %[[EXTRACT]] : f32
+ // CHECK: } -> tensor<?x?x?x?xf32>
+ // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAL_4]] {{\[\[}}0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
+ // CHECK: %[[SHAPE:.*]] = tensor.from_elements
+ // CHECK: %[[RESULT:.*]] = tensor.reshape %[[COLLAPSED]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+ // CHECK: return %[[RESULT]] : tensor<?x?x?x?xf32>
+ // CHECK: }
+   %c1 = arith.constant 1 : index
+   %c0 = arith.constant 0 : index
+   %c2 = arith.constant 2 : index
+   %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
+   %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
+   %0 = arith.index_cast %dim : index to i64
+   %1 = arith.index_cast %dim_0 : index to i64
+   %collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
+   %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+   %dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+   %2 = affine.apply #map()[%dim, %dim_0]
+   %3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
+   %4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
+   ^bb0(%in: i64, %out: f32):
+     %7 = arith.index_cast %in : i64 to index
+     %8 = linalg.index 1 : index
+     %9 = linalg.index 2 : index
+     %extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
+     linalg.yield %extracted : f32
+   } -> tensor<?x?x?xf32>
+   %5 = arith.index_cast %dim_1 : index to i64
+   %6 = arith.index_cast %dim_2 : index to i64
+   %from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
+   %reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
+   return %reshape : tensor<?x?x?x?xf32>
+ }