Skip to content

Commit dec8af7

Browse files
committed
[mlir] Move SelectOp from Standard to Arithmetic
This is part of splitting up the standard dialect. See https://llvm.discourse.group/t/standard-dialect-the-final-chapter/ for discussion. Differential Revision: https://reviews.llvm.org/D118648
1 parent 6a8ba31 commit dec8af7

File tree

116 files changed

+1033
-1135
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

116 files changed

+1033
-1135
lines changed

flang/lib/Optimizer/Builder/Character.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ mlir::Value genMin(fir::FirOpBuilder &builder, mlir::Location loc,
434434
mlir::Value a, mlir::Value b) {
435435
auto cmp =
436436
builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, b);
437-
return builder.create<mlir::SelectOp>(loc, cmp, a, b);
437+
return builder.create<mlir::arith::SelectOp>(loc, cmp, a, b);
438438
}
439439

440440
void fir::factory::CharacterExprHelper::createAssign(
@@ -532,7 +532,8 @@ fir::CharBoxValue fir::factory::CharacterExprHelper::createSubstring(
532532
auto zero = builder.createIntegerConstant(loc, substringLen.getType(), 0);
533533
auto cdt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
534534
substringLen, zero);
535-
substringLen = builder.create<mlir::SelectOp>(loc, cdt, zero, substringLen);
535+
substringLen =
536+
builder.create<mlir::arith::SelectOp>(loc, cdt, zero, substringLen);
536537

537538
return {substringRef, substringLen};
538539
}
@@ -570,8 +571,8 @@ fir::factory::CharacterExprHelper::createLenTrim(const fir::CharBoxValue &str) {
570571
// Compute length after iteration (zero if all blanks)
571572
mlir::Value newLen =
572573
builder.create<arith::AddIOp>(loc, iterWhile.getResult(1), one);
573-
auto result =
574-
builder.create<mlir::SelectOp>(loc, iterWhile.getResult(0), zero, newLen);
574+
auto result = builder.create<mlir::arith::SelectOp>(
575+
loc, iterWhile.getResult(0), zero, newLen);
575576
return builder.createConvert(loc, builder.getCharacterLengthType(), result);
576577
}
577578

flang/lib/Optimizer/Builder/MutableBox.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -675,8 +675,8 @@ void fir::factory::genReallocIfNeeded(fir::FirOpBuilder &builder,
675675
// reallocate = reallocate || previous != required
676676
auto cmp = builder.create<arith::CmpIOp>(
677677
loc, arith::CmpIPredicate::ne, castPrevious, required);
678-
mustReallocate =
679-
builder.create<mlir::SelectOp>(loc, cmp, cmp, mustReallocate);
678+
mustReallocate = builder.create<mlir::arith::SelectOp>(
679+
loc, cmp, cmp, mustReallocate);
680680
};
681681
llvm::SmallVector<mlir::Value> previousLbounds;
682682
llvm::SmallVector<mlir::Value> previousExtents =

flang/lib/Optimizer/Builder/Runtime/Numeric.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,8 @@ mlir::Value fir::runtime::genNearest(fir::FirOpBuilder &builder,
295295
mlir::Value False = builder.createIntegerConstant(loc, boolTy, 0);
296296
mlir::Value True = builder.createIntegerConstant(loc, boolTy, 1);
297297

298-
mlir::Value positive = builder.create<mlir::SelectOp>(loc, cmp, True, False);
298+
mlir::Value positive =
299+
builder.create<mlir::arith::SelectOp>(loc, cmp, True, False);
299300
auto args = fir::runtime::createArguments(builder, loc, funcTy, x, positive);
300301

301302
return builder.create<fir::CallOp>(loc, func, args).getResult(0);

flang/lib/Optimizer/Transforms/RewriteLoop.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class CfgLoopConv : public mlir::OpRewritePattern<fir::DoLoopOp> {
7575
auto cond = rewriter.create<mlir::arith::CmpIOp>(
7676
loc, arith::CmpIPredicate::sle, iters, zero);
7777
auto one = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 1);
78-
iters = rewriter.create<mlir::SelectOp>(loc, cond, one, iters);
78+
iters = rewriter.create<mlir::arith::SelectOp>(loc, cond, one, iters);
7979
}
8080

8181
llvm::SmallVector<mlir::Value> loopOperands;

flang/test/Fir/loop02.fir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func private @y(%addr : !fir.ref<index>)
2323
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
2424
// CHECK: %[[VAL_7:.*]] = arith.cmpi sle, %[[VAL_5]], %[[VAL_6]] : index
2525
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
26-
// CHECK: %[[VAL_9:.*]] = select %[[VAL_7]], %[[VAL_8]], %[[VAL_5]] : index
26+
// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_7]], %[[VAL_8]], %[[VAL_5]] : index
2727
// CHECK: br ^bb1(%[[VAL_1]], %[[VAL_9]] : index, index)
2828
// CHECK: ^bb1(%[[VAL_10:.*]]: index, %[[VAL_11:.*]]: index):
2929
// CHECK: %[[VAL_12:.*]] = arith.constant 0 : index

flang/unittests/Optimizer/Builder/Runtime/NumericTest.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ void testGenNearest(fir::FirOpBuilder &builder, mlir::Type xType,
5656
checkCallOp(nearest.getDefiningOp(), fctName, 2, /*addLocArg=*/false);
5757
auto callOp = mlir::dyn_cast<fir::CallOp>(nearest.getDefiningOp());
5858
mlir::Value select = callOp.getOperands()[1];
59-
EXPECT_TRUE(mlir::isa<mlir::SelectOp>(select.getDefiningOp()));
60-
auto selectOp = mlir::dyn_cast<mlir::SelectOp>(select.getDefiningOp());
59+
EXPECT_TRUE(mlir::isa<mlir::arith::SelectOp>(select.getDefiningOp()));
60+
auto selectOp = mlir::dyn_cast<mlir::arith::SelectOp>(select.getDefiningOp());
6161
mlir::Value cmp = selectOp.getCondition();
6262
EXPECT_TRUE(mlir::isa<mlir::arith::CmpFOp>(cmp.getDefiningOp()));
6363
auto cmpOp = mlir::dyn_cast<mlir::arith::CmpFOp>(cmp.getDefiningOp());

mlir/benchmark/python/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def setup_passes(mlir_module):
2929
f"convert-scf-to-std,"
3030
f"func-bufferize,"
3131
f"arith-bufferize,"
32-
f"builtin.func(tensor-bufferize,std-bufferize,finalizing-bufferize),"
32+
f"builtin.func(tensor-bufferize,finalizing-bufferize),"
3333
f"convert-vector-to-llvm"
3434
f"{{reassociate-fp-reductions=1 enable-index-optimizations=1}},"
3535
f"lower-affine,"

mlir/docs/Bufferization.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ The code, slightly simplified and annotated, is reproduced here:
8787
pm.addNestedPass<FuncOp>(createTCPBufferizePass()); // Bufferizes the downstream `tcp` dialect.
8888
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
8989
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
90-
pm.addNestedPass<FuncOp>(createStdBufferizePass());
9190
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
9291
pm.addPass(createFuncBufferizePass());
9392

mlir/include/mlir/Dialect/Arithmetic/IR/Arithmetic.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/OpDefinition.h"
1313
#include "mlir/IR/OpImplementation.h"
1414
#include "mlir/Interfaces/CastInterfaces.h"
15+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1516
#include "mlir/Interfaces/SideEffectInterfaces.h"
1617
#include "mlir/Interfaces/VectorInterfaces.h"
1718

mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticOps.td

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1088,7 +1088,7 @@ def Arith_CmpIOp : Arith_CompareOpOfAnyRank<"cmpi"> {
10881088
%x = arith.cmpi "eq", %lhs, %rhs : vector<4xi64>
10891089

10901090
// Generic form of the same operation.
1091-
%x = "std.arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64}
1091+
%x = "arith.cmpi"(%lhs, %rhs) {predicate = 0 : i64}
10921092
: (vector<4xi64>, vector<4xi64>) -> vector<4xi1>
10931093
```
10941094
}];
@@ -1161,4 +1161,55 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf"> {
11611161
let hasFolder = 1;
11621162
}
11631163

1164+
//===----------------------------------------------------------------------===//
1165+
// SelectOp
1166+
//===----------------------------------------------------------------------===//
1167+
1168+
def SelectOp : Arith_Op<"select", [
1169+
AllTypesMatch<["true_value", "false_value", "result"]>
1170+
] # ElementwiseMappable.traits> {
1171+
let summary = "select operation";
1172+
let description = [{
1173+
The `arith.select` operation chooses one value based on a binary condition
1174+
supplied as its first operand. If the value of the first operand is `1`,
1175+
the second operand is chosen, otherwise the third operand is chosen.
1176+
The second and the third operand must have the same type.
1177+
1178+
The operation applies to vectors and tensors elementwise given the _shape_
1179+
of all operands is identical. The choice is made for each element
1180+
individually based on the value at the same position as the element in the
1181+
condition operand. If an i1 is provided as the condition, the entire vector
1182+
or tensor is chosen.
1183+
1184+
Example:
1185+
1186+
```mlir
1187+
// Custom form of scalar selection.
1188+
%x = arith.select %cond, %true, %false : i32
1189+
1190+
// Generic form of the same operation.
1191+
%x = "arith.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
1192+
1193+
// Element-wise vector selection.
1194+
%vx = arith.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32>
1195+
1196+
// Full vector selection.
1197+
%vx = arith.select %cond, %vtrue, %vfalse : vector<42xf32>
1198+
```
1199+
}];
1200+
1201+
let arguments = (ins BoolLike:$condition,
1202+
AnyType:$true_value,
1203+
AnyType:$false_value);
1204+
let results = (outs AnyType:$result);
1205+
1206+
let hasCanonicalizer = 1;
1207+
let hasFolder = 1;
1208+
let hasVerifier = 1;
1209+
1210+
// FIXME: Switch this to use the declarative assembly format.
1211+
let printer = [{ return ::print(p, *this); }];
1212+
let parser = [{ return ::parse$cppClass(parser, result); }];
1213+
}
1214+
11641215
#endif // ARITHMETIC_OPS

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
146146

147147
Note: This method can return multiple OpOperands, indicating that the
148148
given OpResult may at runtime alias with any of the OpOperands. This
149-
is useful for branches and for ops such as `std.select`.
149+
is useful for branches and for ops such as `arith.select`.
150150
}],
151151
/*retType=*/"SmallVector<OpOperand *>",
152152
/*methodName=*/"getAliasingOpOperand",

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 0 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -456,57 +456,6 @@ def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
456456
let hasVerifier = 1;
457457
}
458458

459-
//===----------------------------------------------------------------------===//
460-
// SelectOp
461-
//===----------------------------------------------------------------------===//
462-
463-
def SelectOp : Std_Op<"select", [NoSideEffect,
464-
AllTypesMatch<["true_value", "false_value", "result"]>,
465-
DeclareOpInterfaceMethods<VectorUnrollOpInterface>] #
466-
ElementwiseMappable.traits> {
467-
let summary = "select operation";
468-
let description = [{
469-
The `select` operation chooses one value based on a binary condition
470-
supplied as its first operand. If the value of the first operand is `1`,
471-
the second operand is chosen, otherwise the third operand is chosen.
472-
The second and the third operand must have the same type.
473-
474-
The operation applies to vectors and tensors elementwise given the _shape_
475-
of all operands is identical. The choice is made for each element
476-
individually based on the value at the same position as the element in the
477-
condition operand. If an i1 is provided as the condition, the entire vector
478-
or tensor is chosen.
479-
480-
The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used
481-
to implement `min` and `max` with signed or unsigned comparison semantics.
482-
483-
Example:
484-
485-
```mlir
486-
// Custom form of scalar selection.
487-
%x = select %cond, %true, %false : i32
488-
489-
// Generic form of the same operation.
490-
%x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32
491-
492-
// Element-wise vector selection.
493-
%vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32>
494-
495-
// Full vector selection.
496-
%vx = std.select %cond, %vtrue, %vfalse : vector<42xf32>
497-
```
498-
}];
499-
500-
let arguments = (ins BoolLike:$condition,
501-
AnyType:$true_value,
502-
AnyType:$false_value);
503-
let results = (outs AnyType:$result);
504-
505-
let hasCanonicalizer = 1;
506-
let hasFolder = 1;
507-
let hasVerifier = 1;
508-
}
509-
510459
//===----------------------------------------------------------------------===//
511460
// SwitchOp
512461
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/StandardOps/Transforms/BufferizableOpInterfaceImpl.h

Lines changed: 0 additions & 18 deletions
This file was deleted.

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ class BufferizeTypeConverter;
2323

2424
class RewritePatternSet;
2525

26-
/// Creates an instance of std bufferization pass.
27-
std::unique_ptr<Pass> createStdBufferizePass();
28-
2926
/// Creates an instance of func bufferization pass.
3027
std::unique_ptr<Pass> createFuncBufferizePass();
3128

mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,6 @@
1111

1212
include "mlir/Pass/PassBase.td"
1313

14-
def StdBufferize : Pass<"std-bufferize", "FuncOp"> {
15-
let summary = "Bufferize the std dialect";
16-
let constructor = "mlir::createStdBufferizePass()";
17-
}
18-
1914
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
2015
let summary = "Bufferize func/call/return ops";
2116
let description = [{

mlir/include/mlir/Dialect/Tosa/Utils/CoversionUtils.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ SmallVector<Value> condenseValues(const SmallVector<Value> &values);
3030

3131
// Takes the parameters for a clamp and turns it into a series of ops.
3232
template <typename T, typename P>
33-
mlir::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
34-
arith::ConstantOp max, P pred, OpBuilder &rewriter) {
33+
arith::SelectOp clampHelper(Location loc, Value arg, arith::ConstantOp min,
34+
arith::ConstantOp max, P pred,
35+
OpBuilder &rewriter) {
3536
auto smallerThanMin = rewriter.create<T>(loc, pred, arg, min);
3637
auto minOrArg =
37-
rewriter.create<mlir::SelectOp>(loc, smallerThanMin, min, arg);
38+
rewriter.create<arith::SelectOp>(loc, smallerThanMin, min, arg);
3839
auto largerThanMax = rewriter.create<T>(loc, pred, max, arg);
39-
return rewriter.create<mlir::SelectOp>(loc, largerThanMax, max, minOrArg);
40+
return rewriter.create<arith::SelectOp>(loc, largerThanMax, max, minOrArg);
4041
}
4142

4243
// Returns the values in an attribute as an array of values.

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,14 +1368,14 @@ struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
13681368
///
13691369
/// Example:
13701370
/// ```
1371-
/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
1371+
/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val)
13721372
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
13731373
/// -> tensor<?xf32>
13741374
/// ```
13751375
/// can be scalarized to
13761376
///
13771377
/// ```
1378-
/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1378+
/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar)
13791379
/// : (i1, f32, f32) -> f32
13801380
/// ```
13811381
template <typename ConcreteType>
@@ -1430,12 +1430,12 @@ struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
14301430
/// ```
14311431
///
14321432
/// ```
1433-
/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
1433+
/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val)
14341434
/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
14351435
/// ```
14361436
/// can be tensorized to
14371437
/// ```
1438-
/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
1438+
/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val)
14391439
/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
14401440
/// -> tensor<?xf32>
14411441
/// ```

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ class AffineApplyExpander
9696
loc, arith::CmpIPredicate::slt, remainder, zeroCst);
9797
Value correctedRemainder =
9898
builder.create<arith::AddIOp>(loc, remainder, rhs);
99-
Value result = builder.create<SelectOp>(loc, isRemainderNegative,
100-
correctedRemainder, remainder);
99+
Value result = builder.create<arith::SelectOp>(
100+
loc, isRemainderNegative, correctedRemainder, remainder);
101101
return result;
102102
}
103103

@@ -134,12 +134,12 @@ class AffineApplyExpander
134134
loc, arith::CmpIPredicate::slt, lhs, zeroCst);
135135
Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs);
136136
Value dividend =
137-
builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
137+
builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs);
138138
Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
139139
Value correctedQuotient =
140140
builder.create<arith::SubIOp>(loc, noneCst, quotient);
141-
Value result =
142-
builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
141+
Value result = builder.create<arith::SelectOp>(loc, negative,
142+
correctedQuotient, quotient);
143143
return result;
144144
}
145145

@@ -175,14 +175,14 @@ class AffineApplyExpander
175175
Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs);
176176
Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst);
177177
Value dividend =
178-
builder.create<SelectOp>(loc, nonPositive, negated, decremented);
178+
builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented);
179179
Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs);
180180
Value negatedQuotient =
181181
builder.create<arith::SubIOp>(loc, zeroCst, quotient);
182182
Value incrementedQuotient =
183183
builder.create<arith::AddIOp>(loc, quotient, oneCst);
184-
Value result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
185-
incrementedQuotient);
184+
Value result = builder.create<arith::SelectOp>(
185+
loc, nonPositive, negatedQuotient, incrementedQuotient);
186186
return result;
187187
}
188188

@@ -259,7 +259,8 @@ static Value buildMinMaxReductionSeq(Location loc,
259259
Value value = *valueIt++;
260260
for (; valueIt != values.end(); ++valueIt) {
261261
auto cmpOp = builder.create<arith::CmpIOp>(loc, predicate, value, *valueIt);
262-
value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
262+
value = builder.create<arith::SelectOp>(loc, cmpOp.getResult(), value,
263+
*valueIt);
263264
}
264265

265266
return value;

0 commit comments

Comments
 (0)