Skip to content

Commit 6667bc9

Browse files
FMarnoyuxuanchen1997
authored andcommitted
[mlir] Added new attributes to the llvm.call op in llvmir target (#99663)
Summary: The new attributes are: * convergent * no_unwind * will_return * memory effects Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60251463
1 parent 51169b0 commit 6667bc9

File tree

7 files changed

+205
-5
lines changed

7 files changed

+205
-5
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,12 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
652652
"{}">:$fastmathFlags,
653653
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
654654
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
655-
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind);
655+
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind,
656+
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory,
657+
OptionalAttr<UnitAttr>:$convergent,
658+
OptionalAttr<UnitAttr>:$no_unwind,
659+
OptionalAttr<UnitAttr>:$will_return
660+
);
656661
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
657662
let arguments = !con(args, aliasAttrs);
658663
let results = (outs Optional<LLVM_Type>:$result);

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
982982
/*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
983983
/*branch_weights=*/nullptr,
984984
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
985+
/*memory=*/nullptr,
986+
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
985987
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
986988
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
987989
}
@@ -1005,7 +1007,9 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
10051007
getCallOpVarCalleeType(calleeType), callee, args,
10061008
/*fastmathFlags=*/nullptr,
10071009
/*branch_weights=*/nullptr, /*CConv=*/nullptr,
1008-
/*TailCallKind=*/nullptr, /*access_groups=*/nullptr,
1010+
/*TailCallKind=*/nullptr, /*memory=*/nullptr, /*convergent=*/nullptr,
1011+
/*no_unwind=*/nullptr, /*will_return=*/nullptr,
1012+
/*access_groups=*/nullptr,
10091013
/*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
10101014
}
10111015

@@ -1015,7 +1019,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state,
10151019
getCallOpVarCalleeType(calleeType),
10161020
/*callee=*/nullptr, args,
10171021
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1018-
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
1022+
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory=*/nullptr,
1023+
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
10191024
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
10201025
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
10211026
}
@@ -1026,7 +1031,8 @@ void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
10261031
build(builder, state, getCallOpResultTypes(calleeType),
10271032
getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
10281033
/*fastmathFlags=*/nullptr, /*branch_weights=*/nullptr,
1029-
/*CConv=*/nullptr, /*TailCallKind=*/nullptr,
1034+
/*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory=*/nullptr,
1035+
/*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
10301036
/*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
10311037
/*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
10321038
}
@@ -1221,7 +1227,7 @@ void CallOp::print(OpAsmPrinter &p) {
12211227
if (getCConv() != LLVM::CConv::C)
12221228
p << stringifyCConv(getCConv()) << ' ';
12231229

1224-
if(getTailCallKind() != LLVM::TailCallKind::None)
1230+
if (getTailCallKind() != LLVM::TailCallKind::None)
12251231
p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
12261232

12271233
// Print the direct callee if present as a function attribute, or an indirect

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,25 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
219219
}
220220
call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
221221
call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
222+
if (callOp.getConvergentAttr())
223+
call->addFnAttr(llvm::Attribute::Convergent);
224+
if (callOp.getNoUnwindAttr())
225+
call->addFnAttr(llvm::Attribute::NoUnwind);
226+
if (callOp.getWillReturnAttr())
227+
call->addFnAttr(llvm::Attribute::WillReturn);
228+
229+
if (MemoryEffectsAttr memAttr = callOp.getMemoryAttr()) {
230+
llvm::MemoryEffects memEffects =
231+
llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
232+
convertModRefInfoToLLVM(memAttr.getArgMem())) |
233+
llvm::MemoryEffects(
234+
llvm::MemoryEffects::Location::InaccessibleMem,
235+
convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) |
236+
llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
237+
convertModRefInfoToLLVM(memAttr.getOther()));
238+
call->setMemoryEffects(memEffects);
239+
}
240+
222241
moduleTranslation.setAccessGroupsMetadata(callOp, call);
223242
moduleTranslation.setAliasScopeMetadata(callOp, call);
224243
moduleTranslation.setTBAAMetadata(callOp, call);

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,28 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
14711471
callOp.setTailCallKind(
14721472
convertTailCallKindFromLLVM(callInst->getTailCallKind()));
14731473
setFastmathFlagsAttr(inst, callOp);
1474+
1475+
// Handle function attributes.
1476+
if (callInst->hasFnAttr(llvm::Attribute::Convergent))
1477+
callOp.setConvergent(true);
1478+
if (callInst->hasFnAttr(llvm::Attribute::NoUnwind))
1479+
callOp.setNoUnwind(true);
1480+
if (callInst->hasFnAttr(llvm::Attribute::WillReturn))
1481+
callOp.setWillReturn(true);
1482+
1483+
llvm::MemoryEffects memEffects = callInst->getMemoryEffects();
1484+
ModRefInfo othermem = convertModRefInfoFromLLVM(
1485+
memEffects.getModRef(llvm::MemoryEffects::Location::Other));
1486+
ModRefInfo argMem = convertModRefInfoFromLLVM(
1487+
memEffects.getModRef(llvm::MemoryEffects::Location::ArgMem));
1488+
ModRefInfo inaccessibleMem = convertModRefInfoFromLLVM(
1489+
memEffects.getModRef(llvm::MemoryEffects::Location::InaccessibleMem));
1490+
auto memAttr = MemoryEffectsAttr::get(callOp.getContext(), othermem, argMem,
1491+
inaccessibleMem);
1492+
// Only set the attribute when it does not match the default value.
1493+
if (!memAttr.isReadWrite())
1494+
callOp.setMemoryAttr(memAttr);
1495+
14741496
if (!callInst->getType()->isVoidTy())
14751497
mapValue(inst, callOp.getResult());
14761498
else

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
22

3+
4+
// CHECK-LABEL: func @baz
5+
// something to call
6+
llvm.func @baz()
7+
38
// CHECK-LABEL: func @ops
49
// CHECK-SAME: (%[[I32:.*]]: i32, %[[FLOAT:.*]]: f32, %[[PTR1:.*]]: !llvm.ptr, %[[PTR2:.*]]: !llvm.ptr, %[[BOOL:.*]]: i1, %[[VPTR1:.*]]: !llvm.vec<2 x ptr>)
510
func.func @ops(%arg0: i32, %arg1: f32,
@@ -93,6 +98,19 @@ func.func @ops(%arg0: i32, %arg1: f32,
9398
llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) : !llvm.ptr, (i32, i32) -> ()
9499
llvm.call %variadic_func(%arg0, %arg0) vararg(!llvm.func<void (i32, ...)>) {fastmathFlags = #llvm.fastmath<fast>} : !llvm.ptr, (i32, i32) -> ()
95100

101+
// Function call attributes
102+
// CHECK: llvm.call @baz() {convergent} : () -> ()
103+
llvm.call @baz() {convergent} : () -> ()
104+
105+
// CHECK: llvm.call @baz() {no_unwind} : () -> ()
106+
llvm.call @baz() {no_unwind} : () -> ()
107+
108+
// CHECK: llvm.call @baz() {will_return} : () -> ()
109+
llvm.call @baz() {will_return} : () -> ()
110+
111+
// CHECK: llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> ()
112+
llvm.call @baz() {memory = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = write>} : () -> ()
113+
96114
// Terminator operations and their successors.
97115
//
98116
// CHECK: llvm.br ^[[BB1:.*]]

mlir/test/Target/LLVMIR/Import/instructions.ll

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,64 @@ define void @varargs_call(i32 %0) {
528528

529529
; // -----
530530

531+
; CHECK: llvm.func @f()
532+
declare void @f()
533+
534+
; CHECK-LABEL: @call_convergent
535+
define void @call_convergent() {
536+
; CHECK: llvm.call @f() {convergent}
537+
call void @f() convergent
538+
ret void
539+
}
540+
541+
; // -----
542+
543+
; CHECK: llvm.func @f()
544+
declare void @f()
545+
546+
; CHECK-LABEL: @call_no_unwind
547+
define void @call_no_unwind() {
548+
; CHECK: llvm.call @f() {no_unwind}
549+
call void @f() nounwind
550+
ret void
551+
}
552+
553+
; // -----
554+
555+
; CHECK: llvm.func @f()
556+
declare void @f()
557+
558+
; CHECK-LABEL: @call_will_return
559+
define void @call_will_return() {
560+
; CHECK: llvm.call @f() {will_return}
561+
call void @f() willreturn
562+
ret void
563+
}
564+
565+
; // -----
566+
567+
; CHECK: llvm.func @f()
568+
declare void @f()
569+
570+
; CHECK-LABEL: @call_memory_effects
571+
define void @call_memory_effects() {
572+
; CHECK: llvm.call @f() {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>}
573+
call void @f() memory(none)
574+
; CHECK: llvm.call @f() {memory = #llvm.memory_effects<other = none, argMem = write, inaccessibleMem = read>}
575+
call void @f() memory(none, argmem: write, inaccessiblemem: read)
576+
; CHECK: llvm.call @f() {memory = #llvm.memory_effects<other = write, argMem = none, inaccessibleMem = write>}
577+
call void @f() memory(write, argmem: none)
578+
; CHECK: llvm.call @f() {memory = #llvm.memory_effects<other = readwrite, argMem = readwrite, inaccessibleMem = read>}
579+
call void @f() memory(readwrite, inaccessiblemem: read)
580+
; CHECK: llvm.call @f()
581+
; CHECK-NOT: #llvm.memory_effects
582+
; CHECK-SAME: : () -> ()
583+
call void @f() memory(readwrite)
584+
ret void
585+
}
586+
587+
; // -----
588+
531589
%sub_struct = type { i32, i8 }
532590
%my_struct = type { %sub_struct, [4 x i32] }
533591

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2477,3 +2477,75 @@ llvm.func @willreturn() attributes { will_return } {
24772477

24782478
// CHECK: #[[ATTRS]]
24792479
// CHECK-SAME: willreturn
2480+
2481+
// -----
2482+
2483+
llvm.func @f()
2484+
2485+
// CHECK-LABEL: @convergent_call
2486+
// CHECK: call void @f() #[[ATTRS:[0-9]+]]
2487+
llvm.func @convergent_call() {
2488+
llvm.call @f() {convergent} : () -> ()
2489+
llvm.return
2490+
}
2491+
2492+
// CHECK: #[[ATTRS]]
2493+
// CHECK-SAME: convergent
2494+
2495+
// -----
2496+
2497+
llvm.func @f()
2498+
2499+
// CHECK-LABEL: @nounwind_call
2500+
// CHECK: call void @f() #[[ATTRS:[0-9]+]]
2501+
llvm.func @nounwind_call() {
2502+
llvm.call @f() {no_unwind} : () -> ()
2503+
llvm.return
2504+
}
2505+
2506+
// CHECK: #[[ATTRS]]
2507+
// CHECK-SAME: nounwind
2508+
2509+
// -----
2510+
2511+
llvm.func @f()
2512+
2513+
// CHECK-LABEL: @willreturn_call
2514+
// CHECK: call void @f() #[[ATTRS:[0-9]+]]
2515+
llvm.func @willreturn_call() {
2516+
llvm.call @f() {will_return} : () -> ()
2517+
llvm.return
2518+
}
2519+
2520+
// CHECK: #[[ATTRS]]
2521+
// CHECK-SAME: willreturn
2522+
2523+
// -----
2524+
2525+
llvm.func @fa()
2526+
llvm.func @fb()
2527+
llvm.func @fc()
2528+
llvm.func @fd()
2529+
2530+
// CHECK-LABEL: @mem_none_call
2531+
// CHECK: call void @fa() #[[ATTRS_0:[0-9]+]]
2532+
// CHECK: call void @fb() #[[ATTRS_1:[0-9]+]]
2533+
// CHECK: call void @fc() #[[ATTRS_2:[0-9]+]]
2534+
// CHECK: call void @fd() #[[ATTRS_3:[0-9]+]]
2535+
llvm.func @mem_none_call() {
2536+
llvm.call @fa() {memory = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>} : () -> ()
2537+
llvm.call @fb() {memory = #llvm.memory_effects<other = read, argMem = none, inaccessibleMem = write>} : () -> ()
2538+
llvm.call @fc() {memory = #llvm.memory_effects<other = read, argMem = read, inaccessibleMem = write>} : () -> ()
2539+
llvm.call @fd() {memory = #llvm.memory_effects<other = readwrite, argMem = read, inaccessibleMem = readwrite>} : () -> ()
2540+
llvm.return
2541+
2542+
}
2543+
2544+
// CHECK: #[[ATTRS_0]]
2545+
// CHECK-SAME: memory(none)
2546+
// CHECK: #[[ATTRS_1]]
2547+
// CHECK-SAME: memory(read, argmem: none, inaccessiblemem: write)
2548+
// CHECK: #[[ATTRS_2]]
2549+
// CHECK-SAME: memory(read, inaccessiblemem: write)
2550+
// CHECK: #[[ATTRS_3]]
2551+
// CHECK-SAME: memory(readwrite, argmem: read)

0 commit comments

Comments
 (0)