Skip to content

Commit 3b232f0

Browse files
[mlir][linalg] LinalgOp: Disallow mixed tensor/buffer semantics (llvm#80660)
Related discussion: https://github.com/llvm/llvm-project/pull/73908/files#r1414913030. This change fixes llvm#73547.
1 parent be083db commit 3b232f0

File tree

4 files changed

+29
-81
lines changed

4 files changed

+29
-81
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,11 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
10411041
LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
10421042
LinalgOp linalgOp = cast<LinalgOp>(op);
10431043

1044+
// Mixed tensor/buffer operands are not allowed.
1045+
if (!linalgOp.hasPureTensorSemantics() &&
1046+
!linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
1047+
return op->emitOpError("expected to have pure tensor or buffer semantics");
1048+
10441049
// Before checking indexing maps, we need to make sure the attributes
10451050
// referenced by it are valid.
10461051
if (linalgOp.hasDynamicIndexingMaps())

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 14 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,16 @@ func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : ten
102102
// -----
103103

104104
// CHECK-LABEL: func @linalg_effects(
105-
// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
106-
// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>
107-
// CHECK-SAME: %[[C:[a-z0-9]*]]: tensor<?x?xf32>
108-
func.func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
105+
func.func @linalg_effects(
106+
%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>,
107+
%d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) {
109108
// CHECK-NOT: %{{.*}} = linalg.matmul
110-
%t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
109+
%t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
111110
outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
112111

113112
// CHECK: linalg.matmul
114-
linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
115-
outs(%b : memref<?x?xf32>)
113+
linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
114+
outs(%f : memref<?x?xf32>)
116115
return
117116
}
118117

@@ -889,39 +888,38 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
889888
// -----
890889

891890
#map = affine_map<(d0) -> (d0)>
892-
func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
891+
func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
893892
linalg.generic {
894893
indexing_maps = [#map, #map],
895894
iterator_types = ["parallel"]
896-
} ins(%arg0 : tensor<?xf32>)
895+
} ins(%arg0 : memref<?xf32>)
897896
outs(%arg1 : memref<?xf32>) {
898897
^bb0(%arg2 : f32, %arg3 : f32):
899898
linalg.yield %arg2 : f32
900899
}
901900
return
902901
}
903902

904-
// There was a crash in EraseIdentityGenericOp for generic with mixed semantics.
905-
// For now, check generic remained unchanged.
906-
// CHECK-LABEL: func @identity_mixed
907-
// CHECK-SAME: (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
903+
// Do not erase ops with buffer semantics.
904+
// CHECK-LABEL: func @identity_buffer
905+
// CHECK-SAME: (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
908906
// CHECK: linalg.generic {
909907
// CHECK-SAME: indexing_maps = [#map, #map],
910908
// CHECK-SAME: iterator_types = ["parallel"]
911-
// CHECK-SAME: } ins(%[[ARG1]] : tensor<?xf32>)
909+
// CHECK-SAME: } ins(%[[ARG1]] : memref<?xf32>)
912910
// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
913911

914912
// -----
915913

916914
// Just make sure that we don't crash.
917915

918916
// CHECK-LABEL: func @dedeplicate_regression_test
919-
func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
917+
func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) {
920918
%36 = linalg.generic
921919
{indexing_maps = [affine_map<(d0) -> (d0)>,
922920
affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
923921
iterator_types = ["parallel"]}
924-
ins(%1, %1 : memref<4xf32>, memref<4xf32>)
922+
ins(%1, %1 : tensor<4xf32>, tensor<4xf32>)
925923
outs(%0 : tensor<4xf32>) {
926924
^bb0(%in: f32, %in_24: f32, %out: f32):
927925
linalg.yield %in : f32
@@ -937,31 +935,6 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
937935

938936
// -----
939937

940-
#map = affine_map<(d0) -> (d0)>
941-
func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
942-
%0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
943-
linalg.generic {
944-
indexing_maps = [#map, #map],
945-
iterator_types = ["parallel"]
946-
} ins(%0 : tensor<?xf32>)
947-
outs(%arg1 : memref<?xf32>) {
948-
^bb0(%arg2 : f32, %arg3 : f32):
949-
linalg.yield %arg2 : f32
950-
}
951-
return
952-
}
953-
954-
// We need a mixed linalg as a bridge between tensor and memref worlds.
955-
// CHECK-LABEL: func @cast_producer_mixed
956-
// CHECK-SAME: (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
957-
// CHECK: linalg.generic {
958-
// CHECK-SAME: indexing_maps = [#map, #map],
959-
// CHECK-SAME: iterator_types = ["parallel"]
960-
// CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>)
961-
// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
962-
963-
// -----
964-
965938
// CHECK-LABEL: dead_softmax
966939
func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
967940
%0 = tensor.empty() : tensor<16x64x256xf32>

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,43 +1110,3 @@ module {
11101110
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
11111111
// CHECK: linalg.yield %[[T3]] : f32
11121112
// CHECK: return %[[GENERIC]]
1113-
1114-
// -----
1115-
1116-
// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
1117-
#map0 = affine_map<(d0, d1) -> (d0, d1)>
1118-
1119-
// CHECK-LABEL: @mixed_fusion
1120-
func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
1121-
{
1122-
%c0 = arith.constant 0 : index
1123-
%c1 = arith.constant 1 : index
1124-
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
1125-
%1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
1126-
%2 = tensor.empty(%0, %1) : tensor<?x?xf32>
1127-
%3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
1128-
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
1129-
outs(%2 : tensor<?x?xf32>) {
1130-
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
1131-
%4 = arith.addf %arg3, %arg4 : f32
1132-
linalg.yield %4 : f32
1133-
} -> tensor<?x?xf32>
1134-
// CHECK: linalg.generic {
1135-
// CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
1136-
linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
1137-
ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
1138-
outs(%arg8 : memref<?x?xf32>) {
1139-
// CHECK: ^{{[a-zA-Z0-9_]*}}
1140-
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
1141-
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
1142-
// CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
1143-
^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
1144-
// CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
1145-
// CHECK-NOT: linalg.yield
1146-
// CHECK: arith.mulf [[T1]], [[ARG2]]
1147-
// CHECK: linalg.yield
1148-
%5 = arith.mulf %arg5, %arg6 : f32
1149-
linalg.yield %5 : f32
1150-
}
1151-
return
1152-
}

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,3 +770,13 @@ func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
770770
-> tensor<8x8xf32>
771771
return %res : tensor<8x8xf32>
772772
}
773+
774+
// -----
775+
776+
func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<?x?xf32>) {
777+
// expected-error @+1 {{expected to have pure tensor or buffer semantics}}
778+
linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
779+
outs(%c: memref<?x?xf32>)
780+
return
781+
}
782+

0 commit comments

Comments
 (0)