Skip to content

Commit 835a35c

Browse files
Add test for reproducing the crash
1 parent 532a053 commit 835a35c

File tree

1 file changed: +55 additions, −0 deletions

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,58 @@ module {
 977  // CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 978  // CHECK: linalg.yield %[[T3]] : f32
 979  // CHECK: return %[[GENERIC]]
980+
981+
// -----

// Regression test for a fusion crash: a linalg.generic whose input comes from
// a tensor.collapse_shape, with the generic's result consumed by a
// tensor.reshape. The CHECK lines pin the expected (non-crashing) output:
// the generic is expanded to 4-D but the collapse/reshape pair is not fused
// away. NOTE(review): per the commit message ("Add test for reproducing the
// crash") this input previously crashed the elementwise-fusion pass — confirm
// against the accompanying pass change.

#map = affine_map<()[s0, s1] -> (s0 * s1)>
#map1 = affine_map<(d0, d1, d2) -> (d0)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP_0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[$MAP_1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func.func @no_fuse_expand_collapsed_generic_input(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?x?xf32>,
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>,
// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<?x?xi64>)
func.func @no_fuse_expand_collapsed_generic_input(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xi64>, %arg2: tensor<?x?xi64>, %arg3: tensor<?x?xi64>) -> tensor<?x?x?x?xf32> {
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.+}} {{\[\[}}0, 1], [2], [3]] output_shape {{\[}}%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}} : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
// CHECK: %[[OUT:.*]] = tensor.empty(%{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}) : tensor<?x?x?x?xf32>
// CHECK: %[[VAL_4:.*]] = linalg.generic {indexing_maps = [#[[$MAP_0]], #[[$MAP_1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_3]] : tensor<?x?xi64>) outs(%[[OUT]] : tensor<?x?x?x?xf32>) {
// CHECK: ^bb0(%[[VAL_5:.*]]: i64, %[[VAL_6:.*]]: f32):
// CHECK: %[[OFFSETS:.*]] = arith.index_cast %[[VAL_5]] : i64 to index
// CHECK: %[[SIZES:.*]] = linalg.index 2 : index
// CHECK: %[[STRIDES:.*]] = linalg.index 3 : index
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[VAL_0]]{{\[}}%[[OFFSETS]], %[[SIZES]], %[[STRIDES]]] : tensor<?x?x?xf32>
// CHECK: linalg.yield %[[EXTRACT]] : f32
// CHECK: } -> tensor<?x?x?x?xf32>
// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[VAL_4]] {{\[\[}}0, 1], [2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
// CHECK: %[[SHAPE:.*]] = tensor.from_elements
// CHECK: %[[RESULT:.*]] = tensor.reshape %[[COLLAPSED]](%[[SHAPE]]) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
// CHECK: return %[[RESULT]] : tensor<?x?x?x?xf32>
// CHECK: }
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %dim = tensor.dim %arg1, %c0 : tensor<?x?xi64>
  %dim_0 = tensor.dim %arg1, %c1 : tensor<?x?xi64>
  %0 = arith.index_cast %dim : index to i64
  %1 = arith.index_cast %dim_0 : index to i64
  %collapsed = tensor.collapse_shape %arg3 [[0, 1]] : tensor<?x?xi64> into tensor<?xi64>
  %dim_1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %dim_2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
  %2 = affine.apply #map()[%dim, %dim_0]
  %3 = tensor.empty(%2, %dim_1, %dim_2) : tensor<?x?x?xf32>
  %4 = linalg.generic {indexing_maps = [#map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} ins(%collapsed : tensor<?xi64>) outs(%3 : tensor<?x?x?xf32>) {
  ^bb0(%in: i64, %out: f32):
    %7 = arith.index_cast %in : i64 to index
    %8 = linalg.index 1 : index
    %9 = linalg.index 2 : index
    %extracted = tensor.extract %arg0[%7, %8, %9] : tensor<?x?x?xf32>
    linalg.yield %extracted : f32
  } -> tensor<?x?x?xf32>
  %5 = arith.index_cast %dim_1 : index to i64
  %6 = arith.index_cast %dim_2 : index to i64
  %from_elements = tensor.from_elements %0, %1, %5, %6 : tensor<4xi64>
  %reshape = tensor.reshape %4(%from_elements) : (tensor<?x?x?xf32>, tensor<4xi64>) -> tensor<?x?x?x?xf32>
  return %reshape : tensor<?x?x?x?xf32>
}

0 commit comments

Comments (0)