Commit 6a66f9a

SC llvm team authored and committed

Merged main:3b232f066d40a3e91ac27e421a3baeaca0cd59ec into amd-gfx:73e30982e63d

Local branch amd-gfx 73e3098 Merged main:52ada07ef5df2829e90ca2dd48305465a55e8121 into amd-gfx:cb528432c2ad
Remote branch main 3b232f0 [mlir][linalg] `LinalgOp`: Disallow mixed tensor/buffer semantics (llvm#80660)

2 parents 73e3098 + 3b232f0 commit 6a66f9a

File tree: 7 files changed, +45 −90 lines

llvm/include/llvm/Config/llvm-config.h.cmake

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 
 /* Indicate that this is LLVM compiled from the amd-gfx branch. */
 #define LLVM_HAVE_BRANCH_AMD_GFX
-#define LLVM_MAIN_REVISION 490621
+#define LLVM_MAIN_REVISION 490623
 
 /* Define if LLVM_ENABLE_DUMP is enabled */
 #cmakedefine LLVM_ENABLE_DUMP

llvm/include/llvm/Target/TargetSchedule.td

Lines changed: 2 additions & 0 deletions
@@ -399,6 +399,8 @@ def NoSchedPred : MCSchedPredicate<TruePred>;
 class SchedVar<SchedPredicateBase pred, list<SchedReadWrite> selected> {
   SchedPredicateBase Predicate = pred;
   list<SchedReadWrite> Selected = selected;
+  // SchedModel silences warnings but is ignored.
+  SchedMachineModel SchedModel = ?;
 }
 
 // SchedModel silences warnings but is ignored.
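
For context, a minimal sketch of what this enables, using hypothetical record names that are not part of this commit: with SchedModel left unset (?), SchedVar records can be defined as named top-level defs and shared across several SchedWriteVariants, which is the pattern the RISC-V change below adopts.

// Hypothetical sketch -- ExamplePred, ExWriteRes_*, ExPredVar, ExNoPredVar,
// and ExWrite_Variant are illustrative names only.
def ExamplePred : MCSchedPredicate<TruePred>;
def ExWriteRes_Pred : SchedWriteRes<[]>;
def ExWriteRes_NoPred : SchedWriteRes<[]>;

// Named, reusable SchedVars; SchedModel stays ? and is ignored.
def ExPredVar : SchedVar<ExamplePred, [ExWriteRes_Pred]>;
def ExNoPredVar : SchedVar<NoSchedPred, [ExWriteRes_NoPred]>;

def ExWrite_Variant : SchedWriteVariant<[ExPredVar, ExNoPredVar]>;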

llvm/lib/Target/RISCV/RISCVScheduleV.td

Lines changed: 13 additions & 8 deletions
@@ -88,20 +88,25 @@ multiclass LMULWriteResMXVariant<string name, SchedPredicateBase Pred,
     let ReleaseAtCycles = noPredReleaseCycles;
   }
 
+  // Define SchedVars
+  def nameMX # PredSchedVar
+      : SchedVar<Pred, [!cast<SchedWriteRes>(NAME # nameMX # "_Pred")]>;
+  def nameMX # NoPredSchedVar
+      : SchedVar<NoSchedPred, [!cast<SchedWriteRes>(NAME # nameMX #"_NoPred")]>;
+  // Allow multiclass to refer to SchedVars -- need to have NAME prefix.
+  defvar PredSchedVar = !cast<SchedVar>(NAME # nameMX # PredSchedVar);
+  defvar NoPredSchedVar = !cast<SchedVar>(NAME # nameMX # NoPredSchedVar);
+
   // Tie behavior to predicate
-  def NAME # nameMX # "_Variant" : SchedWriteVariant<[
-    SchedVar<Pred, [!cast<SchedWriteRes>(NAME # nameMX # "_Pred")]>,
-    SchedVar<NoSchedPred, [!cast<SchedWriteRes>(NAME # nameMX # "_NoPred")]>
-  ]>;
+  def NAME # nameMX # "_Variant"
+      : SchedWriteVariant<[PredSchedVar, NoPredSchedVar]>;
   def : SchedAlias<
          !cast<SchedReadWrite>(nameMX),
          !cast<SchedReadWrite>(NAME # nameMX # "_Variant")>;
 
   if IsWorstCase then {
-    def NAME # name # "_WorstCase_Variant" : SchedWriteVariant<[
-      SchedVar<Pred, [!cast<SchedWriteRes>(NAME # nameMX # "_Pred")]>,
-      SchedVar<NoSchedPred, [!cast<SchedWriteRes>(NAME # nameMX # "_NoPred")]>
-    ]>;
+    def NAME # name # "_WorstCase_Variant"
+        : SchedWriteVariant<[PredSchedVar, NoPredSchedVar]>;
     def : SchedAlias<
            !cast<SchedReadWrite>(name # "_WorstCase"),
            !cast<SchedReadWrite>(NAME # name # "_WorstCase_Variant")>;
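
The defvar indirection above hinges on TableGen's NAME prefix rule: defs inside a multiclass are implicitly prefixed with the name given at defm time, and NAME expands to that prefix, so !cast lookups from within the multiclass must spell it out. A standalone sketch under those assumptions (all names hypothetical, not from this commit):

def ExVarWriteRes : SchedWriteRes<[]>;

multiclass ExampleVariant {
  // This def is really named <defm-name>_Var once instantiated.
  def _Var : SchedVar<NoSchedPred, [ExVarWriteRes]>;
  // NAME expands to <defm-name>, so the !cast finds the def above.
  defvar TheVar = !cast<SchedVar>(NAME # "_Var");
  def _Variant : SchedWriteVariant<[TheVar]>;
}

// Instantiation defines ExWrite_Var and ExWrite_Variant.
defm ExWrite : ExampleVariant;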

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 5 additions & 0 deletions
@@ -1041,6 +1041,11 @@ int64_t LinalgOp::getIndexingMapIndex(OpOperand *opOperand) {
 LogicalResult mlir::linalg::detail::verifyStructuredOpInterface(Operation *op) {
   LinalgOp linalgOp = cast<LinalgOp>(op);
 
+  // Mixed tensor/buffer operands are not allowed.
+  if (!linalgOp.hasPureTensorSemantics() &&
+      !linalgOp.hasPureBufferSemantics() && op->getNumOperands() > 0)
+    return op->emitOpError("expected to have pure tensor or buffer semantics");
+
   // Before checking indexing maps, we need to make sure the attributes
   // referenced by it are valid.
   if (linalgOp.hasDynamicIndexingMaps())

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 14 additions & 41 deletions
@@ -102,17 +102,16 @@ func.func @tensor.cast.unranked(%a : tensor<*xf32>, %b : tensor<*xf32>, %c : ten
 // -----
 
 // CHECK-LABEL: func @linalg_effects(
-// CHECK-SAME: %[[A:[a-z0-9]*]]: tensor<?x?xf32>
-// CHECK-SAME: %[[B:[a-z0-9]*]]: memref<?x?xf32>
-// CHECK-SAME: %[[C:[a-z0-9]*]]: tensor<?x?xf32>
-func.func @linalg_effects(%a : tensor<?x?xf32>, %b : memref<?x?xf32>, %c : tensor<?x?xf32>) {
+func.func @linalg_effects(
+    %a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : tensor<?x?xf32>,
+    %d : memref<?x?xf32>, %e : memref<?x?xf32>, %f : memref<?x?xf32>) {
   // CHECK-NOT: %{{.*}} = linalg.matmul
-  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, memref<?x?xf32>)
+  %t = linalg.matmul ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
                     outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
 
   // CHECK: linalg.matmul
-  linalg.matmul ins(%a, %c : tensor<?x?xf32>, tensor<?x?xf32>)
-               outs(%b : memref<?x?xf32>)
+  linalg.matmul ins(%d, %e : memref<?x?xf32>, memref<?x?xf32>)
+               outs(%f : memref<?x?xf32>)
   return
 }
 
@@ -889,39 +888,38 @@ func.func @fold_multi_use_generic_op_with_consumer(%arg0 : tensor<?x?x?xf32>) ->
 // -----
 
 #map = affine_map<(d0) -> (d0)>
-func.func @identity_mixed(%arg0 : tensor<?xf32>, %arg1: memref<?xf32>) {
+func.func @identity_buffer(%arg0 : memref<?xf32>, %arg1: memref<?xf32>) {
   linalg.generic {
     indexing_maps = [#map, #map],
     iterator_types = ["parallel"]
-  } ins(%arg0 : tensor<?xf32>)
+  } ins(%arg0 : memref<?xf32>)
     outs(%arg1 : memref<?xf32>) {
   ^bb0(%arg2 : f32, %arg3 : f32):
     linalg.yield %arg2 : f32
   }
   return
 }
 
-// There was a crash in EraseIdentityGenericOp for generic with mixed semantics.
-// For now, check generic remained unchanged.
-// CHECK-LABEL: func @identity_mixed
-// CHECK-SAME: (%[[ARG1:.*]]: tensor<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
+// Do not erase ops with buffer semantics.
+// CHECK-LABEL: func @identity_buffer
+// CHECK-SAME: (%[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32>)
 // CHECK: linalg.generic {
 // CHECK-SAME: indexing_maps = [#map, #map],
 // CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: } ins(%[[ARG1]] : tensor<?xf32>)
+// CHECK-SAME: } ins(%[[ARG1]] : memref<?xf32>)
 // CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
 
 // -----
 
 // Just make sure that we don't crash.
 
 // CHECK-LABEL: func @dedeplicate_regression_test
-func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
+func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: tensor<4xf32>) {
   %36 = linalg.generic
         {indexing_maps = [affine_map<(d0) -> (d0)>,
                           affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
         iterator_types = ["parallel"]}
-        ins(%1, %1 : memref<4xf32>, memref<4xf32>)
+        ins(%1, %1 : tensor<4xf32>, tensor<4xf32>)
         outs(%0 : tensor<4xf32>) {
   ^bb0(%in: f32, %in_24: f32, %out: f32):
     linalg.yield %in : f32
@@ -937,31 +935,6 @@ func.func @dedeplicate_regression_test(%0: tensor<4xf32>, %1: memref<4xf32>) {
 
 // -----
 
-#map = affine_map<(d0) -> (d0)>
-func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
-  %0 = tensor.cast %arg0 : tensor<5xf32> to tensor<?xf32>
-  linalg.generic {
-    indexing_maps = [#map, #map],
-    iterator_types = ["parallel"]
-  } ins(%0 : tensor<?xf32>)
-    outs(%arg1 : memref<?xf32>) {
-  ^bb0(%arg2 : f32, %arg3 : f32):
-    linalg.yield %arg2 : f32
-  }
-  return
-}
-
-// We need a mixed linalg as a bridge between tensor and memref worlds.
-// CHECK-LABEL: func @cast_producer_mixed
-// CHECK-SAME: (%[[ARG1:.*]]: tensor<5xf32>, %[[ARG2:.*]]: memref<?xf32>)
-// CHECK: linalg.generic {
-// CHECK-SAME: indexing_maps = [#map, #map],
-// CHECK-SAME: iterator_types = ["parallel"]
-// CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>)
-// CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
-
-// -----
-
 // CHECK-LABEL: dead_softmax
 func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
   %0 = tensor.empty() : tensor<16x64x256xf32>

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 0 additions & 40 deletions
@@ -1110,43 +1110,3 @@ module {
 // CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
 // CHECK: linalg.yield %[[T3]] : f32
 // CHECK: return %[[GENERIC]]
-
-// -----
-
-// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-
-// CHECK-LABEL: @mixed_fusion
-func.func @mixed_fusion(%arg0: tensor<?x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<?x?xf32>, %arg8 : memref<?x?xf32>)
-{
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
-      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%2 : tensor<?x?xf32>) {
-    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
-      %4 = arith.addf %arg3, %arg4 : f32
-      linalg.yield %4 : f32
-  } -> tensor<?x?xf32>
-  // CHECK: linalg.generic {
-  // CHECK-SAME: indexing_maps = {{\[}}[[$MAP0]], [[$MAP0]], [[$MAP0]], [[$MAP0]]{{\]}}
-  linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
-      ins(%3, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
-      outs(%arg8 : memref<?x?xf32>) {
-    // CHECK: ^{{[a-zA-Z0-9_]*}}
-    // CHECK-SAME: [[ARG0:%[a-zA-Z0-9_]*]]
-    // CHECK-SAME: [[ARG1:%[a-zA-Z0-9_]*]]
-    // CHECK-SAME: [[ARG2:%[a-zA-Z0-9_]*]]
-    ^bb0(%arg5: f32, %arg6: f32, %arg7: f32):
-      // CHECK: [[T1:%[a-zA-Z0-9_]*]] = arith.addf [[ARG0]], [[ARG1]]
-      // CHECK-NOT: linalg.yield
-      // CHECK: arith.mulf [[T1]], [[ARG2]]
-      // CHECK: linalg.yield
-      %5 = arith.mulf %arg5, %arg6 : f32
-      linalg.yield %5 : f32
-  }
-  return
-}

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 10 additions & 0 deletions
@@ -770,3 +770,13 @@ func.func @mmt4d_rank_mismatch(%A: tensor<16x16x8x1xf32>,
                  -> tensor<8x8xf32>
     return %res : tensor<8x8xf32>
 }
+
+// -----
+
+func.func @mixed_semantics(%a: tensor<?x?xf32>, %b: tensor<?x?xf32>, %c: memref<?x?xf32>) {
+  // expected-error @+1 {{expected to have pure tensor or buffer semantics}}
+  linalg.matmul ins(%a, %b: tensor<?x?xf32>, tensor<?x?xf32>)
+                outs(%c: memref<?x?xf32>)
+  return
+}
+
