Skip to content

Commit d242aa2

Browse files
shraiyshftynse
authored andcommitted
[MLIR] Added llvm.invoke and llvm.landingpad
Summary: I have tried to implement `llvm.invoke` and `llvm.landingpad`. # `llvm.invoke` is similar to `llvm.call` with two successors added, the first one is the normal label and the second one is unwind label. # `llvm.launchpad` takes a variable number of args with either `catch` or `filter` associated with them. Catch clauses are not array types and filter clauses are array types. This is same as the criteria used by LLVM (https://github.com/llvm/llvm-project/blob/4f82af81a04d711721300f6ca32f402f2ea6faf4/llvm/include/llvm/IR/Instructions.h#L2866) Examples: LLVM IR ``` define i32 @caller(i32 %a) personality i8* bitcast (i32 (...)* @__gxx_personality_v0 to i8*) { invoke i32 @foo(i32 2) to label %success unwind label %fail success: ret i32 2 fail: landingpad {i8*, i32} catch i8** @_ZTIi catch i8** null catch i8* bitcast (i8** @_ZTIi to i8*) filter [1 x i8] [ i8 1 ] ret i32 3 } ``` MLIR LLVM Dialect ``` llvm.func @caller(%arg0: !llvm.i32) -> !llvm.i32 { %0 = llvm.mlir.constant(3 : i32) : !llvm.i32 %1 = llvm.mlir.constant("\01") : !llvm<"[1 x i8]"> %2 = llvm.mlir.addressof @_ZTIi : !llvm<"i8**"> %3 = llvm.bitcast %2 : !llvm<"i8**"> to !llvm<"i8*"> %4 = llvm.mlir.null : !llvm<"i8**"> %5 = llvm.mlir.addressof @_ZTIi : !llvm<"i8**"> %6 = llvm.mlir.constant(2 : i32) : !llvm.i32 %7 = llvm.invoke @foo(%6) to ^bb1 unwind ^bb2 : (!llvm.i32) -> !llvm.i32 ^bb1: // pred: ^bb0 llvm.return %6 : !llvm.i32 ^bb2: // pred: ^bb0 %8 = llvm.landingpad (catch %5 : !llvm<"i8**">) (catch %4 : !llvm<"i8**">) (catch %3 : !llvm<"i8*">) (filter %1 : !llvm<"[1 x i8]">) : !llvm<"{ i8*, i32 }"> llvm.return %0 : !llvm.i32 } ``` Signed-off-by: Shraiysh Vaishay <[email protected]> Differential Revision: https://reviews.llvm.org/D72006
1 parent 06e1289 commit d242aa2

File tree

8 files changed

+525
-19
lines changed

8 files changed

+525
-19
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,41 @@ def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt">;
315315
def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "CreateFPTrunc">;
316316

317317
// Call-related operations.
318+
def LLVM_InvokeOp : LLVM_Op<"invoke", [Terminator]>,
319+
Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
320+
Variadic<LLVM_Type>)>,
321+
Results<(outs Variadic<LLVM_Type>)> {
322+
let builders = [OpBuilder<
323+
"Builder *b, OperationState &result, ArrayRef<Type> tys, "
324+
"FlatSymbolRefAttr callee, ValueRange ops, Block* normal, "
325+
"ValueRange normalOps, Block* unwind, ValueRange unwindOps",
326+
[{
327+
result.addAttribute("callee", callee);
328+
build(b, result, tys, ops, normal, normalOps, unwind, unwindOps);
329+
}]>,
330+
OpBuilder<
331+
"Builder *b, OperationState &result, ArrayRef<Type> tys, "
332+
"ValueRange ops, Block* normal, "
333+
"ValueRange normalOps, Block* unwind, ValueRange unwindOps",
334+
[{
335+
result.addTypes(tys);
336+
result.addOperands(ops);
337+
result.addSuccessor(normal, normalOps);
338+
result.addSuccessor(unwind, unwindOps);
339+
}]>];
340+
let verifier = [{ return ::verify(*this); }];
341+
let parser = [{ return parseInvokeOp(parser, result); }];
342+
let printer = [{ printInvokeOp(p, *this); }];
343+
}
344+
345+
def LLVM_LandingpadOp : LLVM_OneResultOp<"landingpad">,
346+
Arguments<(ins UnitAttr:$cleanup,
347+
Variadic<LLVM_Type>)> {
348+
let verifier = [{ return ::verify(*this); }];
349+
let parser = [{ return parseLandingpadOp(parser, result); }];
350+
let printer = [{ printLandingpadOp(p, *this); }];
351+
}
352+
318353
def LLVM_CallOp : LLVM_Op<"call">,
319354
Arguments<(ins OptionalAttr<FlatSymbolRefAttr>:$callee,
320355
Variadic<LLVM_Type>)>,

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,231 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
272272
return success();
273273
}
274274

275+
///===----------------------------------------------------------------------===//
276+
/// Verifying/Printing/Parsing for LLVM::InvokeOp.
277+
///===----------------------------------------------------------------------===//
278+
279+
static LogicalResult verify(InvokeOp op) {
280+
if (op.getNumResults() > 1)
281+
return op.emitOpError("must have 0 or 1 result");
282+
if (op.getNumSuccessors() != 2)
283+
return op.emitOpError("must have normal and unwind destinations");
284+
285+
if (op.getSuccessor(1)->empty())
286+
return op.emitError(
287+
"must have at least one operation in unwind destination");
288+
289+
// In unwind destination, first operation must be LandingpadOp
290+
if (!isa<LandingpadOp>(op.getSuccessor(1)->front()))
291+
return op.emitError("first operation in unwind destination should be a "
292+
"llvm.landingpad operation");
293+
294+
return success();
295+
}
296+
297+
static void printInvokeOp(OpAsmPrinter &p, InvokeOp &op) {
298+
auto callee = op.callee();
299+
bool isDirect = callee.hasValue();
300+
301+
p << op.getOperationName() << ' ';
302+
303+
// Either function name or pointer
304+
if (isDirect)
305+
p.printSymbolName(callee.getValue());
306+
else
307+
p << op.getOperand(0);
308+
309+
p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
310+
p << " to ";
311+
p.printSuccessorAndUseList(op.getOperation(), 0);
312+
p << " unwind ";
313+
p.printSuccessorAndUseList(op.getOperation(), 1);
314+
315+
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
316+
317+
SmallVector<Type, 8> argTypes(
318+
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));
319+
320+
p << " : "
321+
<< FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
322+
}
323+
324+
/// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
325+
/// `to` bb-id (`[` ssa-use-and-type-list `]`)?
326+
/// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
327+
/// attribute-dict? `:` function-type
328+
static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
329+
SmallVector<OpAsmParser::OperandType, 8> operands;
330+
FunctionType funcType;
331+
SymbolRefAttr funcAttr;
332+
llvm::SMLoc trailingTypeLoc;
333+
Block *normalDest, *unwindDest;
334+
SmallVector<Value, 4> normalOperands, unwindOperands;
335+
336+
// Parse an operand list that will, in practice, contain 0 or 1 operand. In
337+
// case of an indirect call, there will be 1 operand before `(`. In case of a
338+
// direct call, there will be no operands and the parser will stop at the
339+
// function identifier without complaining.
340+
if (parser.parseOperandList(operands))
341+
return failure();
342+
bool isDirect = operands.empty();
343+
344+
// Optionally parse a function identifier.
345+
if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
346+
return failure();
347+
348+
if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
349+
parser.parseKeyword("to") ||
350+
parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
351+
parser.parseKeyword("unwind") ||
352+
parser.parseSuccessorAndUseList(unwindDest, unwindOperands) ||
353+
parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
354+
parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(funcType))
355+
return failure();
356+
357+
if (isDirect) {
358+
// Make sure types match.
359+
if (parser.resolveOperands(operands, funcType.getInputs(),
360+
parser.getNameLoc(), result.operands))
361+
return failure();
362+
result.addTypes(funcType.getResults());
363+
} else {
364+
// Construct the LLVM IR Dialect function type that the first operand
365+
// should match.
366+
if (funcType.getNumResults() > 1)
367+
return parser.emitError(trailingTypeLoc,
368+
"expected function with 0 or 1 result");
369+
370+
Builder &builder = parser.getBuilder();
371+
auto *llvmDialect =
372+
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
373+
LLVM::LLVMType llvmResultType;
374+
if (funcType.getNumResults() == 0) {
375+
llvmResultType = LLVM::LLVMType::getVoidTy(llvmDialect);
376+
} else {
377+
llvmResultType = funcType.getResult(0).dyn_cast<LLVM::LLVMType>();
378+
if (!llvmResultType)
379+
return parser.emitError(trailingTypeLoc,
380+
"expected result to have LLVM type");
381+
}
382+
383+
SmallVector<LLVM::LLVMType, 8> argTypes;
384+
argTypes.reserve(funcType.getNumInputs());
385+
for (Type ty : funcType.getInputs()) {
386+
if (auto argType = ty.dyn_cast<LLVM::LLVMType>())
387+
argTypes.push_back(argType);
388+
else
389+
return parser.emitError(trailingTypeLoc,
390+
"expected LLVM types as inputs");
391+
}
392+
393+
auto llvmFuncType = LLVM::LLVMType::getFunctionTy(llvmResultType, argTypes,
394+
/*isVarArg=*/false);
395+
auto wrappedFuncType = llvmFuncType.getPointerTo();
396+
397+
auto funcArguments = llvm::makeArrayRef(operands).drop_front();
398+
399+
// Make sure that the first operand (indirect callee) matches the wrapped
400+
// LLVM IR function type, and that the types of the other call operands
401+
// match the types of the function arguments.
402+
if (parser.resolveOperand(operands[0], wrappedFuncType, result.operands) ||
403+
parser.resolveOperands(funcArguments, funcType.getInputs(),
404+
parser.getNameLoc(), result.operands))
405+
return failure();
406+
407+
result.addTypes(llvmResultType);
408+
}
409+
result.addSuccessor(normalDest, normalOperands);
410+
result.addSuccessor(unwindDest, unwindOperands);
411+
return success();
412+
}
413+
414+
///===----------------------------------------------------------------------===//
415+
/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
416+
///===----------------------------------------------------------------------===//
417+
418+
static LogicalResult verify(LandingpadOp op) {
419+
Value value;
420+
421+
if (!op.cleanup() && op.getOperands().empty())
422+
return op.emitError("landingpad instruction expects at least one clause or "
423+
"cleanup attribute");
424+
425+
for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
426+
value = op.getOperand(idx);
427+
bool isFilter = value.getType().cast<LLVMType>().isArrayTy();
428+
if (isFilter) {
429+
// FIXME: Verify filter clauses when arrays are appropriately handled
430+
} else {
431+
// catch - global addresses only.
432+
// Bitcast ops should have global addresses as their args.
433+
if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) {
434+
if (auto addrOp =
435+
dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp()))
436+
continue;
437+
return op.emitError("constant clauses expected")
438+
.attachNote(bcOp.getLoc())
439+
<< "global addresses expected as operand to "
440+
"bitcast used in clauses for landingpad";
441+
}
442+
// NullOp and AddressOfOp allowed
443+
if (dyn_cast_or_null<NullOp>(value.getDefiningOp()))
444+
continue;
445+
if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp()))
446+
continue;
447+
return op.emitError("clause #")
448+
<< idx << " is not a known constant - null, addressof, bitcast";
449+
}
450+
}
451+
return success();
452+
}
453+
454+
static void printLandingpadOp(OpAsmPrinter &p, LandingpadOp &op) {
455+
p << op.getOperationName() << (op.cleanup() ? " cleanup " : " ");
456+
457+
// Clauses
458+
for (auto value : op.getOperands()) {
459+
// Similar to llvm - if clause is an array type then it is filter
460+
// clause else catch clause
461+
bool isArrayTy = value.getType().cast<LLVMType>().isArrayTy();
462+
p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
463+
<< value.getType() << ") ";
464+
}
465+
466+
p.printOptionalAttrDict(op.getAttrs(), {"cleanup"});
467+
468+
p << ": " << op.getType();
469+
}
470+
471+
/// <operation> ::= `llvm.landingpad` `cleanup`?
472+
/// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
473+
static ParseResult parseLandingpadOp(OpAsmParser &parser,
474+
OperationState &result) {
475+
// Check for cleanup
476+
if (succeeded(parser.parseOptionalKeyword("cleanup")))
477+
result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
478+
479+
// Parse clauses with types
480+
while (succeeded(parser.parseOptionalLParen()) &&
481+
(succeeded(parser.parseOptionalKeyword("filter")) ||
482+
succeeded(parser.parseOptionalKeyword("catch")))) {
483+
OpAsmParser::OperandType operand;
484+
Type ty;
485+
if (parser.parseOperand(operand) || parser.parseColon() ||
486+
parser.parseType(ty) ||
487+
parser.resolveOperand(operand, ty, result.operands) ||
488+
parser.parseRParen())
489+
return failure();
490+
}
491+
492+
Type type;
493+
if (parser.parseColon() || parser.parseType(type))
494+
return failure();
495+
496+
result.addTypes(type);
497+
return success();
498+
}
499+
275500
//===----------------------------------------------------------------------===//
276501
// Printing/parsing for LLVM::CallOp.
277502
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class Importer {
7676
/// `br` branches to `target`. Append the block arguments to attach to the
7777
/// generated branch op to `blockArguments`. These should be in the same order
7878
/// as the PHIs in `target`.
79-
LogicalResult processBranchArgs(llvm::BranchInst *br,
79+
LogicalResult processBranchArgs(llvm::Instruction *br,
8080
llvm::BasicBlock *target,
8181
SmallVectorImpl<Value> &blockArguments);
8282
/// Returns the standard type equivalent to be used in attributes for the
@@ -422,21 +422,26 @@ GlobalOp Importer::processGlobal(llvm::GlobalVariable *GV) {
422422
}
423423

424424
Value Importer::processConstant(llvm::Constant *c) {
425+
OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin());
425426
if (Attribute attr = getConstantAsAttr(c)) {
426427
// These constants can be represented as attributes.
427428
OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
428429
LLVMType type = processType(c->getType());
429430
if (!type)
430431
return nullptr;
431-
return instMap[c] = b.create<ConstantOp>(unknownLoc, type, attr);
432+
return instMap[c] = bEntry.create<ConstantOp>(unknownLoc, type, attr);
432433
}
433434
if (auto *cn = dyn_cast<llvm::ConstantPointerNull>(c)) {
434-
OpBuilder b(currentEntryBlock, currentEntryBlock->begin());
435435
LLVMType type = processType(cn->getType());
436436
if (!type)
437437
return nullptr;
438-
return instMap[c] = b.create<NullOp>(unknownLoc, type);
438+
return instMap[c] = bEntry.create<NullOp>(unknownLoc, type);
439439
}
440+
if (auto *GV = dyn_cast<llvm::GlobalVariable>(c))
441+
return bEntry.create<AddressOfOp>(UnknownLoc::get(context),
442+
processGlobal(GV),
443+
ArrayRef<NamedAttribute>());
444+
440445
if (auto *ce = dyn_cast<llvm::ConstantExpr>(c)) {
441446
llvm::Instruction *i = ce->getAsInstruction();
442447
OpBuilder::InsertionGuard guard(b);
@@ -471,16 +476,6 @@ Value Importer::processValue(llvm::Value *value) {
471476
return unknownInstMap[value]->getResult(0);
472477
}
473478

474-
if (auto *GV = dyn_cast<llvm::GlobalVariable>(value)) {
475-
auto global = processGlobal(GV);
476-
if (!global)
477-
return nullptr;
478-
return b.create<AddressOfOp>(UnknownLoc::get(context), global,
479-
ArrayRef<NamedAttribute>());
480-
}
481-
482-
// Note, constant global variables are both GlobalVariables and Constants,
483-
// so we handle GlobalVariables first above.
484479
if (auto *c = dyn_cast<llvm::Constant>(value))
485480
return processConstant(c);
486481

@@ -570,7 +565,7 @@ static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) {
570565
// `br` branches to `target`. Return the branch arguments to `br`, in the
571566
// same order of the PHIs in `target`.
572567
LogicalResult
573-
Importer::processBranchArgs(llvm::BranchInst *br, llvm::BasicBlock *target,
568+
Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target,
574569
SmallVectorImpl<Value> &blockArguments) {
575570
for (auto inst = target->begin(); isa<llvm::PHINode>(inst); ++inst) {
576571
auto *PN = cast<llvm::PHINode>(&*inst);
@@ -719,6 +714,49 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
719714
v = op->getResult(0);
720715
return success();
721716
}
717+
case llvm::Instruction::LandingPad: {
718+
llvm::LandingPadInst *lpi = cast<llvm::LandingPadInst>(inst);
719+
SmallVector<Value, 4> ops;
720+
721+
for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++)
722+
ops.push_back(processConstant(lpi->getClause(i)));
723+
724+
b.create<LandingpadOp>(loc, processType(lpi->getType()), lpi->isCleanup(),
725+
ops);
726+
return success();
727+
}
728+
case llvm::Instruction::Invoke: {
729+
llvm::InvokeInst *ii = cast<llvm::InvokeInst>(inst);
730+
731+
SmallVector<Type, 2> tys;
732+
if (!ii->getType()->isVoidTy())
733+
tys.push_back(processType(inst->getType()));
734+
735+
SmallVector<Value, 4> ops;
736+
ops.reserve(inst->getNumOperands() + 1);
737+
for (auto &op : ii->arg_operands())
738+
ops.push_back(processValue(op.get()));
739+
740+
SmallVector<Value, 4> normalArgs, unwindArgs;
741+
processBranchArgs(ii, ii->getNormalDest(), normalArgs);
742+
processBranchArgs(ii, ii->getUnwindDest(), unwindArgs);
743+
744+
Operation *op;
745+
if (llvm::Function *callee = ii->getCalledFunction()) {
746+
op = b.create<InvokeOp>(loc, tys, b.getSymbolRefAttr(callee->getName()),
747+
ops, blocks[ii->getNormalDest()], normalArgs,
748+
blocks[ii->getUnwindDest()], unwindArgs);
749+
} else {
750+
ops.insert(ops.begin(), processValue(ii->getCalledValue()));
751+
op = b.create<InvokeOp>(loc, tys, ops, blocks[ii->getNormalDest()],
752+
normalArgs, blocks[ii->getUnwindDest()],
753+
unwindArgs);
754+
}
755+
756+
if (!ii->getType()->isVoidTy())
757+
v = op->getResult(0);
758+
return success();
759+
}
722760
case llvm::Instruction::GetElementPtr: {
723761
// FIXME: Support inbounds GEPs.
724762
llvm::GetElementPtrInst *gep = cast<llvm::GetElementPtrInst>(inst);

0 commit comments

Comments
 (0)