Skip to content

Commit 519f591

Browse files
committed
[mlir] Add fma operation to std dialect
Will remove `vector.fma` operation in the followup CLs. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D96801
1 parent fb19400 commit 519f591

File tree

4 files changed

+146
-26
lines changed

4 files changed

+146
-26
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 105 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
103103

104104
// Base class for standard arithmetic operations. Requires operands and
105105
// results to be of the same type, but does not constrain them to specific
106-
// types. Individual classes will have `lhs` and `rhs` accessor to operands.
106+
// types.
107107
class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
108108
Op<StandardOps_Dialect, mnemonic,
109109
!listconcat(traits, [NoSideEffect,
@@ -122,6 +122,32 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
122122
}];
123123
}
124124

125+
// Base class for standard binary arithmetic operations.
126+
class ArithmeticBinaryOp<string mnemonic, list<OpTrait> traits = []> :
127+
ArithmeticOp<mnemonic, traits> {
128+
129+
let parser = [{
130+
return impl::parseOneResultSameOperandTypeOp(parser, result);
131+
}];
132+
133+
let printer = [{
134+
return printStandardBinaryOp(this->getOperation(), p);
135+
}];
136+
}
137+
138+
// Base class for standard ternary arithmetic operations.
139+
class ArithmeticTernaryOp<string mnemonic, list<OpTrait> traits = []> :
140+
ArithmeticOp<mnemonic, traits> {
141+
142+
let parser = [{
143+
return impl::parseOneResultSameOperandTypeOp(parser, result);
144+
}];
145+
146+
let printer = [{
147+
return printStandardTernaryOp(this->getOperation(), p);
148+
}];
149+
}
150+
125151
// Base class for standard arithmetic operations on integers, vectors and
126152
// tensors thereof. This operation takes two operands and returns one result,
127153
// each of these is required to be of the same type. This type may be an
@@ -130,8 +156,8 @@ class ArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
130156
//
131157
// <op>i %0, %1 : i32
132158
//
133-
class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
134-
ArithmeticOp<mnemonic,
159+
class IntBinaryOp<string mnemonic, list<OpTrait> traits = []> :
160+
ArithmeticBinaryOp<mnemonic,
135161
!listconcat(traits,
136162
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
137163
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>;
@@ -145,12 +171,27 @@ class IntArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
145171
//
146172
// <op>f %0, %1 : f32
147173
//
148-
class FloatArithmeticOp<string mnemonic, list<OpTrait> traits = []> :
149-
ArithmeticOp<mnemonic,
174+
class FloatBinaryOp<string mnemonic, list<OpTrait> traits = []> :
175+
ArithmeticBinaryOp<mnemonic,
150176
!listconcat(traits,
151177
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
152178
Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>;
153179

180+
// Base class for standard arithmetic ternary operations on floats, vectors and
181+
// tensors thereof. This operation has three operands and returns one result,
182+
// each of these is required to be of the same type. This type may be a
183+
// floating point scalar type, a vector whose element type is a floating point
184+
// type, or a floating point tensor. The custom assembly form of the operation
185+
// is as follows
186+
//
187+
// <op> %0, %1, %2 : f32
188+
//
189+
class FloatTernaryOp<string mnemonic, list<OpTrait> traits = []> :
190+
ArithmeticTernaryOp<mnemonic,
191+
!listconcat(traits,
192+
[DeclareOpInterfaceMethods<VectorUnrollOpInterface>])>,
193+
Arguments<(ins FloatLike:$a, FloatLike:$b, FloatLike:$c)>;
194+
154195
// Base class for memref allocating ops: alloca and alloc.
155196
//
156197
// %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)>
@@ -257,7 +298,7 @@ def AbsFOp : FloatUnaryOp<"absf"> {
257298
// AddFOp
258299
//===----------------------------------------------------------------------===//
259300

260-
def AddFOp : FloatArithmeticOp<"addf"> {
301+
def AddFOp : FloatBinaryOp<"addf"> {
261302
let summary = "floating point addition operation";
262303
let description = [{
263304
Syntax:
@@ -294,7 +335,7 @@ def AddFOp : FloatArithmeticOp<"addf"> {
294335
// AddIOp
295336
//===----------------------------------------------------------------------===//
296337

297-
def AddIOp : IntArithmeticOp<"addi", [Commutative]> {
338+
def AddIOp : IntBinaryOp<"addi", [Commutative]> {
298339
let summary = "integer addition operation";
299340
let description = [{
300341
Syntax:
@@ -418,7 +459,7 @@ def AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
418459
// AndOp
419460
//===----------------------------------------------------------------------===//
420461

421-
def AndOp : IntArithmeticOp<"and", [Commutative]> {
462+
def AndOp : IntBinaryOp<"and", [Commutative]> {
422463
let summary = "integer binary and";
423464
let description = [{
424465
Syntax:
@@ -1269,7 +1310,7 @@ def ConstantOp : Std_Op<"constant",
12691310
// CopySignOp
12701311
//===----------------------------------------------------------------------===//
12711312

1272-
def CopySignOp : FloatArithmeticOp<"copysign"> {
1313+
def CopySignOp : FloatBinaryOp<"copysign"> {
12731314
let summary = "A copysign operation";
12741315
let description = [{
12751316
Syntax:
@@ -1384,11 +1425,49 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> {
13841425
// DivFOp
13851426
//===----------------------------------------------------------------------===//
13861427

1387-
def DivFOp : FloatArithmeticOp<"divf"> {
1428+
def DivFOp : FloatBinaryOp<"divf"> {
13881429
let summary = "floating point division operation";
13891430
let hasFolder = 1;
13901431
}
13911432

1433+
//===----------------------------------------------------------------------===//
1434+
// FmaFOp
1435+
//===----------------------------------------------------------------------===//
1436+
1437+
def FmaFOp : FloatTernaryOp<"fmaf"> {
1438+
let summary = "floating point fused multipy-add operation";
1439+
let description = [{
1440+
Syntax:
1441+
1442+
```
1443+
operation ::= ssa-id `=` `std.fmaf` ssa-use `,` ssa-use `,` ssa-use `:` type
1444+
```
1445+
1446+
The `fmaf` operation takes three operands and returns one result, each of
1447+
these is required to be the same type. This type may be a floating point
1448+
scalar type, a vector whose element type is a floating point type, or a
1449+
floating point tensor.
1450+
1451+
Example:
1452+
1453+
```mlir
1454+
// Scalar fused multiply-add: d = a*b + c
1455+
%d = fmaf %a, %b, %c : f64
1456+
1457+
// SIMD vector fused multiply-add, e.g. for Intel SSE.
1458+
%i = fmaf %f, %g, %h : vector<4xf32>
1459+
1460+
// Tensor fused multiply-add.
1461+
%w = fmaf %x, %y, %z : tensor<4x?xbf16>
1462+
```
1463+
1464+
The semantics of the operation correspond to those of the `llvm.fma`
1465+
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-fma-intrinsic). In the
1466+
particular case of lowering to LLVM, this is guaranteed to lower
1467+
to the `llvm.fma.*` intrinsic.
1468+
}];
1469+
}
1470+
13921471
//===----------------------------------------------------------------------===//
13931472
// FPExtOp
13941473
//===----------------------------------------------------------------------===//
@@ -1854,7 +1933,7 @@ def MemRefReshapeOp: Std_Op<"memref_reshape", [
18541933
// MulFOp
18551934
//===----------------------------------------------------------------------===//
18561935

1857-
def MulFOp : FloatArithmeticOp<"mulf"> {
1936+
def MulFOp : FloatBinaryOp<"mulf"> {
18581937
let summary = "floating point multiplication operation";
18591938
let description = [{
18601939
Syntax:
@@ -1891,7 +1970,7 @@ def MulFOp : FloatArithmeticOp<"mulf"> {
18911970
// MulIOp
18921971
//===----------------------------------------------------------------------===//
18931972

1894-
def MulIOp : IntArithmeticOp<"muli", [Commutative]> {
1973+
def MulIOp : IntBinaryOp<"muli", [Commutative]> {
18951974
let summary = "integer multiplication operation";
18961975
let hasFolder = 1;
18971976
}
@@ -1933,7 +2012,7 @@ def NegFOp : FloatUnaryOp<"negf"> {
19332012
// OrOp
19342013
//===----------------------------------------------------------------------===//
19352014

1936-
def OrOp : IntArithmeticOp<"or", [Commutative]> {
2015+
def OrOp : IntBinaryOp<"or", [Commutative]> {
19372016
let summary = "integer binary or";
19382017
let description = [{
19392018
Syntax:
@@ -2040,7 +2119,7 @@ def RankOp : Std_Op<"rank", [NoSideEffect]> {
20402119
// RemFOp
20412120
//===----------------------------------------------------------------------===//
20422121

2043-
def RemFOp : FloatArithmeticOp<"remf"> {
2122+
def RemFOp : FloatBinaryOp<"remf"> {
20442123
let summary = "floating point division remainder operation";
20452124
}
20462125

@@ -2141,7 +2220,7 @@ def SelectOp : Std_Op<"select", [NoSideEffect,
21412220
// ShiftLeftOp
21422221
//===----------------------------------------------------------------------===//
21432222

2144-
def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
2223+
def ShiftLeftOp : IntBinaryOp<"shift_left"> {
21452224
let summary = "integer left-shift";
21462225
let description = [{
21472226
The shift_left operation shifts an integer value to the left by a variable
@@ -2161,7 +2240,7 @@ def ShiftLeftOp : IntArithmeticOp<"shift_left"> {
21612240
// SignedDivIOp
21622241
//===----------------------------------------------------------------------===//
21632242

2164-
def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
2243+
def SignedDivIOp : IntBinaryOp<"divi_signed"> {
21652244
let summary = "signed integer division operation";
21662245
let description = [{
21672246
Syntax:
@@ -2196,7 +2275,7 @@ def SignedDivIOp : IntArithmeticOp<"divi_signed"> {
21962275
// SignedFloorDivIOp
21972276
//===----------------------------------------------------------------------===//
21982277

2199-
def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
2278+
def SignedFloorDivIOp : IntBinaryOp<"floordivi_signed"> {
22002279
let summary = "signed floor integer division operation";
22012280
let description = [{
22022281
Syntax:
@@ -2225,7 +2304,7 @@ def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> {
22252304
// SignedCeilDivIOp
22262305
//===----------------------------------------------------------------------===//
22272306

2228-
def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
2307+
def SignedCeilDivIOp : IntBinaryOp<"ceildivi_signed"> {
22292308
let summary = "signed ceil integer division operation";
22302309
let description = [{
22312310
Syntax:
@@ -2253,7 +2332,7 @@ def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> {
22532332
// SignedRemIOp
22542333
//===----------------------------------------------------------------------===//
22552334

2256-
def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
2335+
def SignedRemIOp : IntBinaryOp<"remi_signed"> {
22572336
let summary = "signed integer division remainder operation";
22582337
let description = [{
22592338
Syntax:
@@ -2288,7 +2367,7 @@ def SignedRemIOp : IntArithmeticOp<"remi_signed"> {
22882367
// SignedShiftRightOp
22892368
//===----------------------------------------------------------------------===//
22902369

2291-
def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> {
2370+
def SignedShiftRightOp : IntBinaryOp<"shift_right_signed"> {
22922371
let summary = "signed integer right-shift";
22932372
let description = [{
22942373
The shift_right_signed operation shifts an integer value to the right by
@@ -2488,7 +2567,7 @@ def StoreOp : Std_Op<"store",
24882567
// SubFOp
24892568
//===----------------------------------------------------------------------===//
24902569

2491-
def SubFOp : FloatArithmeticOp<"subf"> {
2570+
def SubFOp : FloatBinaryOp<"subf"> {
24922571
let summary = "floating point subtraction operation";
24932572
let hasFolder = 1;
24942573
}
@@ -2497,7 +2576,7 @@ def SubFOp : FloatArithmeticOp<"subf"> {
24972576
// SubIOp
24982577
//===----------------------------------------------------------------------===//
24992578

2500-
def SubIOp : IntArithmeticOp<"subi"> {
2579+
def SubIOp : IntBinaryOp<"subi"> {
25012580
let summary = "integer subtraction operation";
25022581
let hasFolder = 1;
25032582
}
@@ -3173,7 +3252,7 @@ def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> {
31733252
// UnsignedDivIOp
31743253
//===----------------------------------------------------------------------===//
31753254

3176-
def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
3255+
def UnsignedDivIOp : IntBinaryOp<"divi_unsigned"> {
31773256
let summary = "unsigned integer division operation";
31783257
let description = [{
31793258
Syntax:
@@ -3208,7 +3287,7 @@ def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> {
32083287
// UnsignedRemIOp
32093288
//===----------------------------------------------------------------------===//
32103289

3211-
def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
3290+
def UnsignedRemIOp : IntBinaryOp<"remi_unsigned"> {
32123291
let summary = "unsigned integer division remainder operation";
32133292
let description = [{
32143293
Syntax:
@@ -3243,7 +3322,7 @@ def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> {
32433322
// UnsignedShiftRightOp
32443323
//===----------------------------------------------------------------------===//
32453324

3246-
def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> {
3325+
def UnsignedShiftRightOp : IntBinaryOp<"shift_right_unsigned"> {
32473326
let summary = "unsigned integer right-shift";
32483327
let description = [{
32493328
The shift_right_unsigned operation shifts an integer value to the right by
@@ -3332,7 +3411,7 @@ def ViewOp : Std_Op<"view", [
33323411
// XOrOp
33333412
//===----------------------------------------------------------------------===//
33343413

3335-
def XOrOp : IntArithmeticOp<"xor", [Commutative]> {
3414+
def XOrOp : IntBinaryOp<"xor", [Commutative]> {
33363415
let summary = "integer binary xor";
33373416
let description = [{
33383417
The `xor` operation takes two operands and returns one result, each of these

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1662,6 +1662,7 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
16621662
using ExpOpLowering = VectorConvertToLLVMPattern<math::ExpOp, LLVM::ExpOp>;
16631663
using Exp2OpLowering = VectorConvertToLLVMPattern<math::Exp2Op, LLVM::Exp2Op>;
16641664
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
1665+
using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
16651666
using Log10OpLowering =
16661667
VectorConvertToLLVMPattern<math::Log10Op, LLVM::Log10Op>;
16671668
using Log2OpLowering = VectorConvertToLLVMPattern<math::Log2Op, LLVM::Log2Op>;
@@ -3775,6 +3776,7 @@ void mlir::populateStdToLLVMNonMemoryConversionPatterns(
37753776
ExpOpLowering,
37763777
Exp2OpLowering,
37773778
FloorFOpLowering,
3779+
FmaFOpLowering,
37783780
GenericAtomicRMWOpLowering,
37793781
LogOpLowering,
37803782
Log10OpLowering,

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,32 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
158158
p << " : " << op->getResult(0).getType();
159159
}
160160

161+
/// A custom ternary operation printer that omits the "std." prefix from the
162+
/// operation names.
163+
static void printStandardTernaryOp(Operation *op, OpAsmPrinter &p) {
164+
assert(op->getNumOperands() == 3 && "ternary op should have three operands");
165+
assert(op->getNumResults() == 1 && "ternary op should have one result");
166+
167+
// If not all the operand and result types are the same, just use the
168+
// generic assembly form to avoid omitting information in printing.
169+
auto resultType = op->getResult(0).getType();
170+
if (op->getOperand(0).getType() != resultType ||
171+
op->getOperand(1).getType() != resultType ||
172+
op->getOperand(2).getType() != resultType) {
173+
p.printGenericOp(op);
174+
return;
175+
}
176+
177+
int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
178+
p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
179+
<< op->getOperand(0) << ", " << op->getOperand(1) << ", "
180+
<< op->getOperand(2);
181+
p.printOptionalAttrDict(op->getAttrs());
182+
183+
// Now we can output only one type for all operands and the result.
184+
p << " : " << op->getResult(0).getType();
185+
}
186+
161187
/// A custom cast operation printer that omits the "std." prefix from the
162188
/// operation names.
163189
static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {

mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,16 @@ func @powf(%arg0 : f64) {
223223
%0 = math.powf %arg0, %arg0 : f64
224224
std.return
225225
}
226+
227+
// -----
228+
229+
// CHECK-LABEL: func @fmaf(
230+
// CHECK-SAME: %[[ARG0:.*]]: f32
231+
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
232+
func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
233+
// CHECK: %[[S:.*]] = "llvm.intr.fma"(%[[ARG0]], %[[ARG0]], %[[ARG0]]) : (f32, f32, f32) -> f32
234+
%0 = fmaf %arg0, %arg0, %arg0 : f32
235+
// CHECK: %[[V:.*]] = "llvm.intr.fma"(%[[ARG1]], %[[ARG1]], %[[ARG1]]) : (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32>
236+
%1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32>
237+
std.return
238+
}

0 commit comments

Comments
 (0)