Skip to content

Commit 60a0c61

Browse files
Frank Laubftynse
authored andcommitted
[MLIR] LLVM dialect: Add llvm.atomicrmw
Summary: This op is the counterpart to LLVM's atomicrmw instruction. Note that volatile and syncscope attributes are not yet supported. This will be useful for upcoming parallel versions of `affine.for` and generally for reduction-like semantics. Differential Revision: https://reviews.llvm.org/D72741
1 parent 37e2560 commit 60a0c61

File tree

6 files changed

+292
-0
lines changed

6 files changed

+292
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,4 +723,56 @@ def LLVM_Prefetch : LLVM_ZeroResultOp<"intr.prefetch">,
723723
}];
724724
}
725725

726+
def AtomicBinOpXchg : I64EnumAttrCase<"xchg", 0>;
727+
def AtomicBinOpAdd : I64EnumAttrCase<"add", 1>;
728+
def AtomicBinOpSub : I64EnumAttrCase<"sub", 2>;
729+
def AtomicBinOpAnd : I64EnumAttrCase<"_and", 3>;
730+
def AtomicBinOpNand : I64EnumAttrCase<"nand", 4>;
731+
def AtomicBinOpOr : I64EnumAttrCase<"_or", 5>;
732+
def AtomicBinOpXor : I64EnumAttrCase<"_xor", 6>;
733+
def AtomicBinOpMax : I64EnumAttrCase<"max", 7>;
734+
def AtomicBinOpMin : I64EnumAttrCase<"min", 8>;
735+
def AtomicBinOpUMax : I64EnumAttrCase<"umax", 9>;
736+
def AtomicBinOpUMin : I64EnumAttrCase<"umin", 10>;
737+
def AtomicBinOpFAdd : I64EnumAttrCase<"fadd", 11>;
738+
def AtomicBinOpFSub : I64EnumAttrCase<"fsub", 12>;
739+
def AtomicBinOp : I64EnumAttr<
740+
"AtomicBinOp",
741+
"llvm.atomicrmw binary operations",
742+
[AtomicBinOpXchg, AtomicBinOpAdd, AtomicBinOpSub, AtomicBinOpAnd,
743+
AtomicBinOpNand, AtomicBinOpOr, AtomicBinOpXor, AtomicBinOpMax,
744+
AtomicBinOpMin, AtomicBinOpUMax, AtomicBinOpUMin, AtomicBinOpFAdd,
745+
AtomicBinOpFSub]> {
746+
let cppNamespace = "::mlir::LLVM";
747+
}
748+
749+
def AtomicOrderingNotAtomic : I64EnumAttrCase<"not_atomic", 0>;
750+
def AtomicOrderingUnordered : I64EnumAttrCase<"unordered", 1>;
751+
def AtomicOrderingMonotonic : I64EnumAttrCase<"monotonic", 2>;
752+
def AtomicOrderingAcquire : I64EnumAttrCase<"acquire", 4>;
753+
def AtomicOrderingRelease : I64EnumAttrCase<"release", 5>;
754+
def AtomicOrderingAcquireRelease : I64EnumAttrCase<"acq_rel", 6>;
755+
def AtomicOrderingSequentiallyConsistent : I64EnumAttrCase<"seq_cst", 7>;
756+
def AtomicOrdering : I64EnumAttr<
757+
"AtomicOrdering",
758+
"Atomic ordering for LLVM's memory model",
759+
[AtomicOrderingNotAtomic, AtomicOrderingUnordered, AtomicOrderingMonotonic,
760+
AtomicOrderingAcquire, AtomicOrderingRelease, AtomicOrderingAcquireRelease,
761+
AtomicOrderingSequentiallyConsistent]> {
762+
let cppNamespace = "::mlir::LLVM";
763+
}
764+
765+
def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw">,
766+
Arguments<(ins AtomicBinOp:$bin_op, LLVM_Type:$ptr, LLVM_Type:$val,
767+
AtomicOrdering:$ordering)>,
768+
Results<(outs LLVM_Type:$res)> {
769+
let llvmBuilder = [{
770+
$res = builder.CreateAtomicRMW(getLLVMAtomicBinOp($bin_op), $ptr, $val,
771+
getLLVMAtomicOrdering($ordering));
772+
}];
773+
let parser = [{ return parseAtomicRMWOp(parser, result); }];
774+
let printer = [{ printAtomicRMWOp(p, *this); }];
775+
let verifier = "return ::verify(*this);";
776+
}
777+
726778
#endif // LLVMIR_OPS

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,107 @@ static LogicalResult verify(LLVM::NullOp op) {
13681368
return success();
13691369
}
13701370

1371+
//===----------------------------------------------------------------------===//
1372+
// Printer, parser and verifier for LLVM::AtomicRMWOp.
1373+
//===----------------------------------------------------------------------===//
1374+
1375+
static void printAtomicRMWOp(OpAsmPrinter &p, AtomicRMWOp &op) {
1376+
p << op.getOperationName() << " ";
1377+
p << '"' << stringifyAtomicBinOp(op.bin_op()) << "\" ";
1378+
p << '"' << stringifyAtomicOrdering(op.ordering()) << "\" ";
1379+
p << op.ptr() << ", " << op.val();
1380+
p.printOptionalAttrDict(op.getAttrs(), {"bin_op", "ordering"});
1381+
p << " : (" << op.ptr().getType() << ", " << op.val().getType() << ") -> "
1382+
<< op.res().getType();
1383+
}
1384+
1385+
// <operation> ::= `llvm.atomicrmw` string-literal string-literal
1386+
// ssa-use `,` ssa-use attribute-dict? `:` type
1387+
static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
1388+
OperationState &result) {
1389+
Type type;
1390+
StringAttr binOp, ordering;
1391+
llvm::SMLoc binOpLoc, orderingLoc, trailingTypeLoc;
1392+
OpAsmParser::OperandType ptr, val;
1393+
if (parser.getCurrentLocation(&binOpLoc) ||
1394+
parser.parseAttribute(binOp, "bin_op", result.attributes) ||
1395+
parser.getCurrentLocation(&orderingLoc) ||
1396+
parser.parseAttribute(ordering, "ordering", result.attributes) ||
1397+
parser.parseOperand(ptr) || parser.parseComma() ||
1398+
parser.parseOperand(val) ||
1399+
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1400+
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
1401+
return failure();
1402+
1403+
// Extract the result type from the trailing function type.
1404+
auto funcType = type.dyn_cast<FunctionType>();
1405+
if (!funcType || funcType.getNumInputs() != 2 ||
1406+
funcType.getNumResults() != 1)
1407+
return parser.emitError(
1408+
trailingTypeLoc,
1409+
"expected trailing function type with two arguments and one result");
1410+
1411+
if (parser.resolveOperand(ptr, funcType.getInput(0), result.operands) ||
1412+
parser.resolveOperand(val, funcType.getInput(1), result.operands))
1413+
return failure();
1414+
1415+
// Replace the string attribute `bin_op` with an integer attribute.
1416+
auto binOpKind = symbolizeAtomicBinOp(binOp.getValue());
1417+
if (!binOpKind) {
1418+
return parser.emitError(binOpLoc)
1419+
<< "'" << binOp.getValue()
1420+
<< "' is an incorrect value of the 'bin_op' attribute";
1421+
}
1422+
1423+
auto binOpValue = static_cast<int64_t>(binOpKind.getValue());
1424+
auto binOpAttr = parser.getBuilder().getI64IntegerAttr(binOpValue);
1425+
result.attributes[0].second = binOpAttr;
1426+
1427+
// Replace the string attribute `ordering` with an integer attribute.
1428+
auto orderingKind = symbolizeAtomicOrdering(ordering.getValue());
1429+
if (!orderingKind) {
1430+
return parser.emitError(orderingLoc)
1431+
<< "'" << ordering.getValue()
1432+
<< "' is an incorrect value of the 'ordering' attribute";
1433+
}
1434+
1435+
auto orderingValue = static_cast<int64_t>(orderingKind.getValue());
1436+
auto orderingAttr = parser.getBuilder().getI64IntegerAttr(orderingValue);
1437+
result.attributes[1].second = orderingAttr;
1438+
1439+
result.addTypes(funcType.getResults());
1440+
return success();
1441+
}
1442+
1443+
static LogicalResult verify(AtomicRMWOp op) {
1444+
auto ptrType = op.ptr().getType().cast<LLVM::LLVMType>();
1445+
if (!ptrType.isPointerTy())
1446+
return op.emitOpError("expected LLVM IR pointer type for operand #0");
1447+
auto valType = op.val().getType().cast<LLVM::LLVMType>();
1448+
if (valType != ptrType.getPointerElementTy())
1449+
return op.emitOpError("expected LLVM IR element type for operand #0 to "
1450+
"match type for operand #1");
1451+
auto resType = op.res().getType().cast<LLVM::LLVMType>();
1452+
if (resType != valType)
1453+
return op.emitOpError(
1454+
"expected LLVM IR result type to match type for operand #1");
1455+
if (op.bin_op() == AtomicBinOp::fadd || op.bin_op() == AtomicBinOp::fsub) {
1456+
if (!valType.getUnderlyingType()->isFloatingPointTy())
1457+
return op.emitOpError("expected LLVM IR floating point type");
1458+
} else if (op.bin_op() == AtomicBinOp::xchg) {
1459+
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1460+
!valType.isIntegerTy(32) && !valType.isIntegerTy(64) &&
1461+
!valType.getUnderlyingType()->isHalfTy() && !valType.isFloatTy() &&
1462+
!valType.isDoubleTy())
1463+
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
1464+
} else {
1465+
if (!valType.isIntegerTy(8) && !valType.isIntegerTy(16) &&
1466+
!valType.isIntegerTy(32) && !valType.isIntegerTy(64))
1467+
return op.emitOpError("expected LLVM IR integer type");
1468+
}
1469+
return success();
1470+
}
1471+
13711472
//===----------------------------------------------------------------------===//
13721473
// LLVMDialect initialization, type parsing, and registration.
13731474
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,58 @@ static llvm::CmpInst::Predicate getLLVMCmpPredicate(FCmpPredicate p) {
211211
llvm_unreachable("incorrect comparison predicate");
212212
}
213213

214+
static llvm::AtomicRMWInst::BinOp getLLVMAtomicBinOp(AtomicBinOp op) {
215+
switch (op) {
216+
case LLVM::AtomicBinOp::xchg:
217+
return llvm::AtomicRMWInst::BinOp::Xchg;
218+
case LLVM::AtomicBinOp::add:
219+
return llvm::AtomicRMWInst::BinOp::Add;
220+
case LLVM::AtomicBinOp::sub:
221+
return llvm::AtomicRMWInst::BinOp::Sub;
222+
case LLVM::AtomicBinOp::_and:
223+
return llvm::AtomicRMWInst::BinOp::And;
224+
case LLVM::AtomicBinOp::nand:
225+
return llvm::AtomicRMWInst::BinOp::Nand;
226+
case LLVM::AtomicBinOp::_or:
227+
return llvm::AtomicRMWInst::BinOp::Or;
228+
case LLVM::AtomicBinOp::_xor:
229+
return llvm::AtomicRMWInst::BinOp::Xor;
230+
case LLVM::AtomicBinOp::max:
231+
return llvm::AtomicRMWInst::BinOp::Max;
232+
case LLVM::AtomicBinOp::min:
233+
return llvm::AtomicRMWInst::BinOp::Min;
234+
case LLVM::AtomicBinOp::umax:
235+
return llvm::AtomicRMWInst::BinOp::UMax;
236+
case LLVM::AtomicBinOp::umin:
237+
return llvm::AtomicRMWInst::BinOp::UMin;
238+
case LLVM::AtomicBinOp::fadd:
239+
return llvm::AtomicRMWInst::BinOp::FAdd;
240+
case LLVM::AtomicBinOp::fsub:
241+
return llvm::AtomicRMWInst::BinOp::FSub;
242+
}
243+
llvm_unreachable("incorrect atomic binary operator");
244+
}
245+
246+
static llvm::AtomicOrdering getLLVMAtomicOrdering(AtomicOrdering ordering) {
247+
switch (ordering) {
248+
case LLVM::AtomicOrdering::not_atomic:
249+
return llvm::AtomicOrdering::NotAtomic;
250+
case LLVM::AtomicOrdering::unordered:
251+
return llvm::AtomicOrdering::Unordered;
252+
case LLVM::AtomicOrdering::monotonic:
253+
return llvm::AtomicOrdering::Monotonic;
254+
case LLVM::AtomicOrdering::acquire:
255+
return llvm::AtomicOrdering::Acquire;
256+
case LLVM::AtomicOrdering::release:
257+
return llvm::AtomicOrdering::Release;
258+
case LLVM::AtomicOrdering::acq_rel:
259+
return llvm::AtomicOrdering::AcquireRelease;
260+
case LLVM::AtomicOrdering::seq_cst:
261+
return llvm::AtomicOrdering::SequentiallyConsistent;
262+
}
263+
llvm_unreachable("incorrect atomic ordering");
264+
}
265+
214266
/// Given a single MLIR operation, create the corresponding LLVM IR operation
215267
/// using the `builder`. LLVM IR Builder does not have a generic interface so
216268
/// this has to be a long chain of `if`s calling different functions with a

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,4 +393,51 @@ llvm.func @recursive_type(%a : !llvm<"%a = type { %a* }">) ->
393393
llvm.return %a : !llvm<"%a = type { %a* }">
394394
}
395395

396+
// -----
397+
398+
// CHECK-LABEL: @atomicrmw_expected_ptr
399+
func @atomicrmw_expected_ptr(%f32 : !llvm.float) {
400+
// expected-error@+1 {{expected LLVM IR pointer type for operand #0}}
401+
%0 = llvm.atomicrmw "fadd" "unordered" %f32, %f32 : (!llvm.float, !llvm.float) -> !llvm.float
402+
llvm.return
403+
}
404+
405+
// -----
406+
// CHECK-LABEL: @atomicrmw_mismatched_operands
407+
func @atomicrmw_mismatched_operands(%f32_ptr : !llvm<"float*">, %i32 : !llvm.i32) {
408+
// expected-error@+1 {{expected LLVM IR element type for operand #0 to match type for operand #1}}
409+
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %i32 : (!llvm<"float*">, !llvm.i32) -> !llvm.float
410+
llvm.return
411+
}
412+
413+
// -----
414+
// CHECK-LABEL: @atomicrmw_mismatched_result
415+
func @atomicrmw_mismatched_operands(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
416+
// expected-error@+1 {{expected LLVM IR result type to match type for operand #1}}
417+
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.i32
418+
llvm.return
419+
}
396420

421+
// -----
422+
// CHECK-LABEL: @atomicrmw_expected_float
423+
func @atomicrmw_expected_float(%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) {
424+
// expected-error@+1 {{expected LLVM IR floating point type}}
425+
%0 = llvm.atomicrmw "fadd" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
426+
llvm.return
427+
}
428+
429+
// -----
430+
// CHECK-LABEL: @atomicrmw_unexpected_xchg_type
431+
func @atomicrmw_xchg_type(%i1_ptr : !llvm<"i1*">, %i1 : !llvm.i1) {
432+
// expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
433+
%0 = llvm.atomicrmw "xchg" "unordered" %i1_ptr, %i1 : (!llvm<"i1*">, !llvm.i1) -> !llvm.i1
434+
llvm.return
435+
}
436+
437+
// -----
438+
// CHECK-LABEL: @atomicrmw_expected_int
439+
func @atomicrmw_expected_int(%f32_ptr : !llvm<"float*">, %f32 : !llvm.float) {
440+
// expected-error@+1 {{expected LLVM IR integer type}}
441+
%0 = llvm.atomicrmw "max" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
442+
llvm.return
443+
}

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,10 @@ func @null() {
218218
%1 = llvm.mlir.null : !llvm<"{void(i32, void()*)*, i64}*">
219219
llvm.return
220220
}
221+
222+
// CHECK-LABEL: @atomics
223+
func @atomics(%arg0 : !llvm<"float*">, %arg1 : !llvm.float) {
224+
// CHECK: llvm.atomicrmw "fadd" "unordered" %{{.*}}, %{{.*}} : (!llvm<"float*">, !llvm.float) -> !llvm.float
225+
%0 = llvm.atomicrmw "fadd" "unordered" %arg0, %arg1 : (!llvm<"float*">, !llvm.float) -> !llvm.float
226+
llvm.return
227+
}

mlir/test/Target/llvmir.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,3 +1086,36 @@ llvm.func @elements_constant_3d_array() -> !llvm<"[2 x [2 x [2 x i32]]]"> {
10861086
%0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm<"[2 x [2 x [2 x i32]]]">
10871087
llvm.return %0 : !llvm<"[2 x [2 x [2 x i32]]]">
10881088
}
1089+
1090+
// CHECK-LABEL: @atomics
1091+
llvm.func @atomics(
1092+
%f32_ptr : !llvm<"float*">, %f32 : !llvm.float,
1093+
%i32_ptr : !llvm<"i32*">, %i32 : !llvm.i32) -> !llvm.float {
1094+
// CHECK: atomicrmw fadd float* %{{.*}}, float %{{.*}} unordered
1095+
%0 = llvm.atomicrmw "fadd" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
1096+
// CHECK: atomicrmw fsub float* %{{.*}}, float %{{.*}} unordered
1097+
%1 = llvm.atomicrmw "fsub" "unordered" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
1098+
// CHECK: atomicrmw xchg float* %{{.*}}, float %{{.*}} monotonic
1099+
%2 = llvm.atomicrmw "xchg" "monotonic" %f32_ptr, %f32 : (!llvm<"float*">, !llvm.float) -> !llvm.float
1100+
// CHECK: atomicrmw add i32* %{{.*}}, i32 %{{.*}} acquire
1101+
%3 = llvm.atomicrmw "add" "acquire" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1102+
// CHECK: atomicrmw sub i32* %{{.*}}, i32 %{{.*}} release
1103+
%4 = llvm.atomicrmw "sub" "release" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1104+
// CHECK: atomicrmw and i32* %{{.*}}, i32 %{{.*}} acq_rel
1105+
%5 = llvm.atomicrmw "_and" "acq_rel" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1106+
// CHECK: atomicrmw nand i32* %{{.*}}, i32 %{{.*}} seq_cst
1107+
%6 = llvm.atomicrmw "nand" "seq_cst" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1108+
// CHECK: atomicrmw or i32* %{{.*}}, i32 %{{.*}} unordered
1109+
%7 = llvm.atomicrmw "_or" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1110+
// CHECK: atomicrmw xor i32* %{{.*}}, i32 %{{.*}} unordered
1111+
%8 = llvm.atomicrmw "_xor" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1112+
// CHECK: atomicrmw max i32* %{{.*}}, i32 %{{.*}} unordered
1113+
%9 = llvm.atomicrmw "max" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1114+
// CHECK: atomicrmw min i32* %{{.*}}, i32 %{{.*}} unordered
1115+
%10 = llvm.atomicrmw "min" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1116+
// CHECK: atomicrmw umax i32* %{{.*}}, i32 %{{.*}} unordered
1117+
%11 = llvm.atomicrmw "umax" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1118+
// CHECK: atomicrmw umin i32* %{{.*}}, i32 %{{.*}} unordered
1119+
%12 = llvm.atomicrmw "umin" "unordered" %i32_ptr, %i32 : (!llvm<"i32*">, !llvm.i32) -> !llvm.i32
1120+
llvm.return %0 : !llvm.float
1121+
}

0 commit comments

Comments
 (0)