Skip to content

Commit 9fb57c8

Browse files
committed
[mlir] Add min/max operations to Standard.
See [RFC: Add min/max ops](https://llvm.discourse.group/t/rfc-add-min-max-operations/4353). The op names follow the naming style used for the Arith dialect in https://reviews.llvm.org/D110200: by analogy with DivSIOp and DivUIOp, the integer ops are named MaxSIOp, MaxUIOp, MinSIOp and MinUIOp. Once the Arith patch lands, these ops will be migrated as well. Differential Revision: https://reviews.llvm.org/D110540
1 parent 20c0280 commit 9fb57c8

File tree

5 files changed

+334
-32
lines changed

5 files changed

+334
-32
lines changed

mlir/docs/Rationale/Rationale.md

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -344,33 +344,6 @@ possible to store the predicate as string attribute, it would have rendered
344344
impossible to implement switching logic based on the comparison kind and made
345345
attribute validity checks (one out of ten possible kinds) more complex.
346346

347-
### 'select' operation to implement min/max
348-
349-
Although `min` and `max` operations are likely to occur as a result of
350-
transforming affine loops in ML functions, we did not make them first-class
351-
operations. Instead, we provide the `select` operation that can be combined with
352-
`cmpi` to implement the minimum and maximum computation. Although they now
353-
require two operations, they are likely to be emitted automatically during the
354-
transformation inside MLIR. On the other hand, there are multiple benefits of
355-
introducing `select`: standalone min/max would concern themselves with the
356-
signedness of the comparison, already taken into account by `cmpi`; `select` can
357-
support floats transparently if used after a float-comparison operation; the
358-
lower-level targets provide `select`-like instructions making the translation
359-
trivial.
360-
361-
This operation could have been implemented with additional control flow: `%r =
362-
select %cond, %t, %f` is equivalent to
363-
364-
```mlir
365-
^bb0:
366-
cond_br %cond, ^bb1(%t), ^bb1(%f)
367-
^bb1(%r):
368-
```
369-
370-
However, this control flow granularity is not available in the ML functions
371-
where min/max, and thus `select`, are likely to appear. In addition, simpler
372-
control flow may be beneficial for optimization in general.
373-
374347
### Regions
375348

376349
#### Attributes of type 'Block'

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,152 @@ def IndexCastOp : ArithmeticCastOp<"index_cast"> {
12471247
let hasCanonicalizer = 1;
12481248
}
12491249

1250+
//===----------------------------------------------------------------------===//
1251+
// MaxFOp
1252+
//===----------------------------------------------------------------------===//
1253+
1254+
def MaxFOp : FloatBinaryOp<"maxf"> {
1255+
let summary = "floating-point maximum operation";
1256+
let description = [{
1257+
Syntax:
1258+
1259+
```
1260+
operation ::= ssa-id `=` `maxf` ssa-use `,` ssa-use `:` type
1261+
```
1262+
1263+
Returns the maximum of the two arguments, treating -0.0 as less than +0.0.
1264+
If one of the arguments is NaN, then the result is also NaN.
1265+
1266+
Example:
1267+
1268+
```mlir
1269+
// Scalar floating-point maximum.
1270+
%a = maxf %b, %c : f64
1271+
```
1272+
}];
1273+
}
1274+
1275+
//===----------------------------------------------------------------------===//
1276+
// MaxSIOp
1277+
//===----------------------------------------------------------------------===//
1278+
1279+
def MaxSIOp : IntBinaryOp<"maxsi"> {
1280+
let summary = "signed integer maximum operation";
1281+
let description = [{
1282+
Syntax:
1283+
1284+
```
1285+
operation ::= ssa-id `=` `maxsi` ssa-use `,` ssa-use `:` type
1286+
```
1287+
1288+
Returns the larger of the two operands, comparing the values as signed integers.
1289+
1290+
Example:
1291+
1292+
```mlir
1293+
// Scalar signed integer maximum.
1294+
%a = maxsi %b, %c : i64
1295+
```
1296+
}];
1297+
}
1298+
1299+
//===----------------------------------------------------------------------===//
1300+
// MaxUIOp
1301+
//===----------------------------------------------------------------------===//
1302+
1303+
def MaxUIOp : IntBinaryOp<"maxui"> {
1304+
let summary = "unsigned integer maximum operation";
1305+
let description = [{
1306+
Syntax:
1307+
1308+
```
1309+
operation ::= ssa-id `=` `maxui` ssa-use `,` ssa-use `:` type
1310+
```
1311+
1312+
Returns the larger of the two operands, comparing the values as unsigned integers.
1313+
1314+
Example:
1315+
1316+
```mlir
1317+
// Scalar unsigned integer maximum.
1318+
%a = maxui %b, %c : i64
1319+
```
1320+
}];
1321+
}
1322+
1323+
//===----------------------------------------------------------------------===//
1324+
// MinFOp
1325+
//===----------------------------------------------------------------------===//
1326+
1327+
def MinFOp : FloatBinaryOp<"minf"> {
1328+
let summary = "floating-point minimum operation";
1329+
let description = [{
1330+
Syntax:
1331+
1332+
```
1333+
operation ::= ssa-id `=` `minf` ssa-use `,` ssa-use `:` type
1334+
```
1335+
1336+
Returns the minimum of the two arguments, treating -0.0 as less than +0.0.
1337+
If one of the arguments is NaN, then the result is also NaN.
1338+
1339+
Example:
1340+
1341+
```mlir
1342+
// Scalar floating-point minimum.
1343+
%a = minf %b, %c : f64
1344+
```
1345+
}];
1346+
}
1347+
1348+
//===----------------------------------------------------------------------===//
1349+
// MinSIOp
1350+
//===----------------------------------------------------------------------===//
1351+
1352+
def MinSIOp : IntBinaryOp<"minsi"> {
1353+
let summary = "signed integer minimum operation";
1354+
let description = [{
1355+
Syntax:
1356+
1357+
```
1358+
operation ::= ssa-id `=` `minsi` ssa-use `,` ssa-use `:` type
1359+
```
1360+
1361+
Returns the smaller of the two operands, comparing the values as signed integers.
1362+
1363+
Example:
1364+
1365+
```mlir
1366+
// Scalar signed integer minimum.
1367+
%a = minsi %b, %c : i64
1368+
```
1369+
}];
1370+
}
1371+
1372+
//===----------------------------------------------------------------------===//
1373+
// MinUIOp
1374+
//===----------------------------------------------------------------------===//
1375+
1376+
def MinUIOp : IntBinaryOp<"minui"> {
1377+
let summary = "unsigned integer minimum operation";
1378+
let description = [{
1379+
Syntax:
1380+
1381+
```
1382+
operation ::= ssa-id `=` `minui` ssa-use `,` ssa-use `:` type
1383+
```
1384+
1385+
Returns the smaller of the two operands, comparing the values as unsigned integers.
1386+
1387+
Example:
1388+
1389+
```mlir
1390+
// Scalar unsigned integer minimum.
1391+
%a = minui %b, %c : i64
1392+
```
1393+
}];
1394+
}
1395+
12501396
//===----------------------------------------------------------------------===//
12511397
// MulFOp
12521398
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,55 @@ struct SignedFloorDivIOpConverter : public OpRewritePattern<SignedFloorDivIOp> {
215215
}
216216
};
217217

218+
static Type getElementTypeOrSelf(Type type) {
219+
if (auto st = type.dyn_cast<ShapedType>())
220+
return st.getElementType();
221+
return type;
222+
}
223+
224+
template <typename OpTy, CmpFPredicate pred>
225+
struct MaxMinFOpConverter : public OpRewritePattern<OpTy> {
226+
public:
227+
using OpRewritePattern<OpTy>::OpRewritePattern;
228+
229+
LogicalResult matchAndRewrite(OpTy op,
230+
PatternRewriter &rewriter) const final {
231+
Value lhs = op.lhs();
232+
Value rhs = op.rhs();
233+
234+
Location loc = op.getLoc();
235+
Value cmp = rewriter.create<CmpFOp>(loc, pred, lhs, rhs);
236+
Value select = rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
237+
238+
auto floatType = getElementTypeOrSelf(lhs.getType()).cast<FloatType>();
239+
Value isNaN = rewriter.create<CmpFOp>(loc, CmpFPredicate::UNO, lhs, rhs);
240+
241+
Value nan = rewriter.create<ConstantFloatOp>(
242+
loc, APFloat::getQNaN(floatType.getFloatSemantics()), floatType);
243+
if (VectorType vectorType = lhs.getType().dyn_cast<VectorType>())
244+
nan = rewriter.create<SplatOp>(loc, vectorType, nan);
245+
246+
rewriter.replaceOpWithNewOp<SelectOp>(op, isNaN, nan, select);
247+
return success();
248+
}
249+
};
250+
251+
template <typename OpTy, CmpIPredicate pred>
252+
struct MaxMinIOpConverter : public OpRewritePattern<OpTy> {
253+
public:
254+
using OpRewritePattern<OpTy>::OpRewritePattern;
255+
LogicalResult matchAndRewrite(OpTy op,
256+
PatternRewriter &rewriter) const final {
257+
Value lhs = op.lhs();
258+
Value rhs = op.rhs();
259+
260+
Location loc = op.getLoc();
261+
Value cmp = rewriter.create<CmpIOp>(loc, pred, lhs, rhs);
262+
rewriter.replaceOpWithNewOp<SelectOp>(op, cmp, lhs, rhs);
263+
return success();
264+
}
265+
};
266+
218267
struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
219268
void runOnFunction() override {
220269
MLIRContext &ctx = getContext();
@@ -232,8 +281,18 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
232281
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
233282
return !op.shape().getType().cast<MemRefType>().hasStaticShape();
234283
});
235-
target.addIllegalOp<SignedCeilDivIOp>();
236-
target.addIllegalOp<SignedFloorDivIOp>();
284+
// clang-format off
285+
target.addIllegalOp<
286+
MaxFOp,
287+
MaxSIOp,
288+
MaxUIOp,
289+
MinFOp,
290+
MinSIOp,
291+
MinUIOp,
292+
SignedCeilDivIOp,
293+
SignedFloorDivIOp
294+
>();
295+
// clang-format on
237296
if (failed(
238297
applyPartialConversion(getFunction(), target, std::move(patterns))))
239298
signalPassFailure();
@@ -243,9 +302,20 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
243302
} // namespace
244303

245304
void mlir::populateStdExpandOpsPatterns(RewritePatternSet &patterns) {
246-
patterns.add<AtomicRMWOpConverter, MemRefReshapeOpConverter,
247-
SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
248-
patterns.getContext());
305+
// clang-format off
306+
patterns.add<
307+
AtomicRMWOpConverter,
308+
MaxMinFOpConverter<MaxFOp, CmpFPredicate::OGT>,
309+
MaxMinFOpConverter<MinFOp, CmpFPredicate::OLT>,
310+
MaxMinIOpConverter<MaxSIOp, CmpIPredicate::sgt>,
311+
MaxMinIOpConverter<MaxUIOp, CmpIPredicate::ugt>,
312+
MaxMinIOpConverter<MinSIOp, CmpIPredicate::slt>,
313+
MaxMinIOpConverter<MinUIOp, CmpIPredicate::ult>,
314+
MemRefReshapeOpConverter,
315+
SignedCeilDivIOpConverter,
316+
SignedFloorDivIOpConverter
317+
>(patterns.getContext());
318+
// clang-format on
249319
}
250320

251321
std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {

mlir/test/Dialect/Standard/expand-ops.mlir

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,92 @@ func @memref_reshape(%input: memref<*xf32>,
109109
// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8],
110110
// CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]]
111111
// CHECK-SAME: : memref<*xf32> to memref<?x?x8xf32>
112+
113+
// -----
114+
115+
// CHECK-LABEL: func @maxf
116+
func @maxf(%a: f32, %b: f32) -> f32 {
117+
%result = maxf(%a, %b): (f32, f32) -> f32
118+
return %result : f32
119+
}
120+
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
121+
// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : f32
122+
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
123+
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
124+
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
125+
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
126+
// CHECK-NEXT: return %[[RESULT]] : f32
127+
128+
// -----
129+
130+
// CHECK-LABEL: func @maxf_vector
131+
func @maxf_vector(%a: vector<4xf16>, %b: vector<4xf16>) -> vector<4xf16> {
132+
%result = maxf(%a, %b): (vector<4xf16>, vector<4xf16>) -> vector<4xf16>
133+
return %result : vector<4xf16>
134+
}
135+
// CHECK-SAME: %[[LHS:.*]]: vector<4xf16>, %[[RHS:.*]]: vector<4xf16>)
136+
// CHECK-NEXT: %[[CMP:.*]] = cmpf ogt, %[[LHS]], %[[RHS]] : vector<4xf16>
137+
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]]
138+
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : vector<4xf16>
139+
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7E00 : f16
140+
// CHECK-NEXT: %[[SPLAT_NAN:.*]] = splat %[[NAN]] : vector<4xf16>
141+
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[SPLAT_NAN]], %[[SELECT]]
142+
// CHECK-NEXT: return %[[RESULT]] : vector<4xf16>
143+
144+
// -----
145+
146+
// CHECK-LABEL: func @minf
147+
func @minf(%a: f32, %b: f32) -> f32 {
148+
%result = minf(%a, %b): (f32, f32) -> f32
149+
return %result : f32
150+
}
151+
// CHECK-SAME: %[[LHS:.*]]: f32, %[[RHS:.*]]: f32)
152+
// CHECK-NEXT: %[[CMP:.*]] = cmpf olt, %[[LHS]], %[[RHS]] : f32
153+
// CHECK-NEXT: %[[SELECT:.*]] = select %[[CMP]], %[[LHS]], %[[RHS]] : f32
154+
// CHECK-NEXT: %[[IS_NAN:.*]] = cmpf uno, %[[LHS]], %[[RHS]] : f32
155+
// CHECK-NEXT: %[[NAN:.*]] = constant 0x7FC00000 : f32
156+
// CHECK-NEXT: %[[RESULT:.*]] = select %[[IS_NAN]], %[[NAN]], %[[SELECT]] : f32
157+
// CHECK-NEXT: return %[[RESULT]] : f32
158+
159+
160+
// -----
161+
162+
// CHECK-LABEL: func @maxsi
163+
func @maxsi(%a: i32, %b: i32) -> i32 {
164+
%result = maxsi(%a, %b): (i32, i32) -> i32
165+
return %result : i32
166+
}
167+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
168+
// CHECK-NEXT: %[[CMP:.*]] = cmpi sgt, %[[LHS]], %[[RHS]] : i32
169+
170+
// -----
171+
172+
// CHECK-LABEL: func @minsi
173+
func @minsi(%a: i32, %b: i32) -> i32 {
174+
%result = minsi(%a, %b): (i32, i32) -> i32
175+
return %result : i32
176+
}
177+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
178+
// CHECK-NEXT: %[[CMP:.*]] = cmpi slt, %[[LHS]], %[[RHS]] : i32
179+
180+
181+
// -----
182+
183+
// CHECK-LABEL: func @maxui
184+
func @maxui(%a: i32, %b: i32) -> i32 {
185+
%result = maxui(%a, %b): (i32, i32) -> i32
186+
return %result : i32
187+
}
188+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
189+
// CHECK-NEXT: %[[CMP:.*]] = cmpi ugt, %[[LHS]], %[[RHS]] : i32
190+
191+
192+
// -----
193+
194+
// CHECK-LABEL: func @minui
195+
func @minui(%a: i32, %b: i32) -> i32 {
196+
%result = minui(%a, %b): (i32, i32) -> i32
197+
return %result : i32
198+
}
199+
// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32)
200+
// CHECK-NEXT: %[[CMP:.*]] = cmpi ult, %[[LHS]], %[[RHS]] : i32

0 commit comments

Comments
 (0)