Skip to content

Commit 85e4e9d

Browse files
authored
[mlir][arith] Further clean up select op definition (#93358)
* Improve the condition type requirement description ('scalar' -> signless i1), to match what is actually verified. * Use the `I1` type predicate instead of `AnyBooleanTypeMatch`. Related discussion: #93351 (comment).
1 parent 4d20f49 commit 85e4e9d

File tree

3 files changed

+9
-15
lines changed

3 files changed

+9
-15
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,21 +1553,15 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
15531553
// SelectOp
15541554
//===----------------------------------------------------------------------===//
15551555

1556-
class AnyBooleanTypeMatch<list<string> names> :
1557-
AnyMatchOperatorTrait<names, "$_self.getType().isSignlessInteger(1)",
1558-
"scalar type">;
1559-
1560-
class ScalarConditionOrMatchingShape<list<string> names> :
1556+
class BooleanConditionOrMatchingShape<string condition, string result> :
15611557
PredOpTrait<
1562-
!head(names) # " is scalar or has matching shape",
1563-
Or<[AnyBooleanTypeMatch<[!head(names)]>.predicate,
1564-
AllShapesMatch<names>.predicate]>> {
1565-
list<string> values = names;
1566-
}
1558+
condition # " is signless i1 or has matching shape",
1559+
Or<[TypeIsPred<condition, I1>,
1560+
AllShapesMatch<[condition, result]>.predicate]>>;
15671561

15681562
def SelectOp : Arith_Op<"select", [Pure,
15691563
AllTypesMatch<["true_value", "false_value", "result"]>,
1570-
ScalarConditionOrMatchingShape<["condition", "result"]>,
1564+
BooleanConditionOrMatchingShape<"condition", "result">,
15711565
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
15721566
] # ElementwiseMappable.traits> {
15731567
let summary = "select operation";

mlir/test/Dialect/Arith/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,15 +832,15 @@ func.func @func() {
832832
// -----
833833

834834
func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2xi64>, %arg2 : tensor<2xi64>) -> tensor<2xi64> {
835-
// expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
835+
// expected-error @+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
836836
%0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2xi64>
837837
return %0 : tensor<2xi64>
838838
}
839839

840840
// -----
841841

842842
func.func @disallow_zero_rank_tensor_with_unranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2x?xi64>, %arg2 : tensor<2x?xi64>) -> tensor<2x?xi64> {
843-
// expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
843+
// expected-error @+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
844844
%0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64>
845845
return %0 : tensor<2x?xi64>
846846
}

mlir/test/IR/invalid-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,15 @@ func.func @func_with_ops(i1, i32, i64) {
8282

8383
func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) {
8484
^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>):
85-
// expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
85+
// expected-error@+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
8686
%r = "arith.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32>
8787
}
8888

8989
// -----
9090

9191
func.func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) {
9292
^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
93-
// expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}}
93+
// expected-error@+1 {{'arith.select' op failed to verify that condition is signless i1 or has matching shape}}
9494
%r = "arith.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
9595
}
9696

0 commit comments

Comments
 (0)