Skip to content

Commit 5c6b295

Browse files
Add test for reproducing the crash
1 parent 532a053 commit 5c6b295

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,93 @@ module {
977977
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
978978
// CHECK: linalg.yield %[[T3]] : f32
979979
// CHECK: return %[[GENERIC]]
980+
981+
// -----
982+
983+
#map = affine_map<()[s0, s1] -> (s0 * s1)>
984+
#map1 = affine_map<(d0, d1, d2) -> (d0)>
985+
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
986+
module {
987+
func.func @no_fusio(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
988+
%c1 = arith.constant 1 : index
989+
%c0 = arith.constant 0 : index
990+
%c2 = arith.constant 2 : index
991+
%dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
992+
%dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
993+
%0 = arith.index_cast %dim : index to i64
994+
%1 = arith.index_cast %dim_0 : index to i64
995+
%collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
996+
%dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
997+
%dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
998+
%2 = affine.apply #map()[%dim, %dim_0]
999+
%3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
1000+
%4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
1001+
^bb0(%in: i64, %out: f32):
1002+
%7 = arith.index_cast %in : i64 to index
1003+
%8 = linalg.index 1 : index
1004+
%9 = linalg.index 2 : index
1005+
%extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
1006+
linalg.yield %extracted : f32
1007+
} -> tensor<?x?x?xf32>
1008+
%5 = arith.index_cast %dim_1 : index to i64
1009+
%6 = arith.index_cast %dim_2 : index to i64
1010+
%from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
1011+
%reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
1012+
return %reshape : tensor<?x?x?x?xf32>
1013+
}
1014+
}
1015+
1016+
// -----
1017+
1018+
// Same IR as the reproducer above, but with FileCheck expectations: the
// collapse_shape on the generic's input is handled by expanding the generic
// to 4-D (expand_shape on the producer, collapse_shape on the result) rather
// than fusing with the trailing tensor.reshape.
#map = affine_map<()[s0, s1] -> (s0 * s1)>
#map1 = affine_map<(d0, d1, d2) -> (d0)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @no_fuse_expand_collapsed_generic_input(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>)
func.func @no_fuse_expand_collapsed_generic_input(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
  // CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.+}} {{\[\[}}0, 1], [2], [3]] output_shape {{\[}}%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  // CHECK: %[[OUT:.*]] = tensor.empty(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) : tensor<?x?x?x?xf32>
  // CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_3]] : tensor<?x?xi64>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) {
  // CHECK: ^bb0(%[[VAL_5:.*]]: i64, %[[VAL_6:.*]]: f32):
  // CHECK: %[[OFFSETS:.*]] = arith.index_cast %[[VAL_5]] : i64 to index
  // CHECK: %[[SIZES:.*]] = linalg.index 2 : index
  // CHECK: %[[STRIDES:.*]] = linalg.index 3 : index
  // CHECK: %[[EXTRACT:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[OFFSETS]], %[[SIZES]], %[[STRIDES]]] : tensor<?x?x?xf32>
  // CHECK: linalg.yield %[[EXTRACT]] : f32
  // CHECK: } -> tensor<?x?x?x?xf32>
  // CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAL_4]] {{\[\[}}0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  // CHECK: %[[SHAPE:.*]] = tensor.from_elements
  // CHECK: %[[RESULT:.*]] = tensor.reshape %[[COLLAPSED]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
  // CHECK: return %[[RESULT]] : tensor<?x?x?x?xf32>
  // CHECK: }
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
  %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
  %0 = arith.index_cast %dim : index to i64
  %1 = arith.index_cast %dim_0 : index to i64
  %collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
  %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %2 = affine.apply #map()[%dim, %dim_0]
  %3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
  %4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
  ^bb0(%in: i64, %out: f32):
    %7 = arith.index_cast %in : i64 to index
    %8 = linalg.index 1 : index
    %9 = linalg.index 2 : index
    %extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x?x?xf32>
  %5 = arith.index_cast %dim_1 : index to i64
  %6 = arith.index_cast %dim_2 : index to i64
  %from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
  %reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
  return %reshape : tensor<?x?x?x?xf32>
}

0 commit comments

Comments
 (0)