Skip to content

Commit 9c6c868

Browse files
marbremgehre-amd
authored andcommitted
[mlir][emitc] Add a declare_func operation (llvm#80297)
This adds the `emitc.declare_func` operation that allows to emit the declaration of an `emitc.func` at a specific location.
1 parent ed20cea commit 9c6c868

File tree

6 files changed

+128
-6
lines changed

6 files changed

+128
-6
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,48 @@ def EmitC_CallOp : EmitC_Op<"call",
460460
}];
461461
}
462462

463+
def EmitC_DeclareFuncOp : EmitC_Op<"declare_func", [
464+
DeclareOpInterfaceMethods<SymbolUserOpInterface>
465+
]> {
466+
let summary = "An operation to declare a function";
467+
let description = [{
468+
The `declare_func` operation allows to insert a function declaration for an
469+
`emitc.func` at a specific position. The operation only requires the `callee`
470+
of the `emitc.func` to be specified as an attribute.
471+
472+
Example:
473+
474+
```mlir
475+
emitc.declare_func @bar
476+
emitc.func @foo(%arg0: i32) -> i32 {
477+
%0 = emitc.call @bar(%arg0) : (i32) -> (i32)
478+
emitc.return %0 : i32
479+
}
480+
481+
emitc.func @bar(%arg0: i32) -> i32 {
482+
emitc.return %arg0 : i32
483+
}
484+
```
485+
486+
```c++
487+
// Code emitted for the operations above.
488+
int32_t bar(int32_t v1);
489+
int32_t foo(int32_t v1) {
490+
int32_t v2 = bar(v1);
491+
return v2;
492+
}
493+
494+
int32_t bar(int32_t v1) {
495+
return v1;
496+
}
497+
```
498+
}];
499+
let arguments = (ins FlatSymbolRefAttr:$sym_name);
500+
let assemblyFormat = [{
501+
$sym_name attr-dict
502+
}];
503+
}
504+
463505
def EmitC_FuncOp : EmitC_Op<"func", [
464506
AutomaticAllocationScope,
465507
FunctionOpInterface, IsolatedFromAbove

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,24 @@ FunctionType CallOp::getCalleeType() {
394394
return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
395395
}
396396

397+
//===----------------------------------------------------------------------===//
398+
// DeclareFuncOp
399+
//===----------------------------------------------------------------------===//
400+
401+
LogicalResult
402+
DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
403+
// Check that the sym_name attribute was specified.
404+
auto fnAttr = getSymNameAttr();
405+
if (!fnAttr)
406+
return emitOpError("requires a 'sym_name' symbol reference attribute");
407+
FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
408+
if (!fn)
409+
return emitOpError() << "'" << fnAttr.getValue()
410+
<< "' does not reference a valid function";
411+
412+
return success();
413+
}
414+
397415
//===----------------------------------------------------------------------===//
398416
// FuncOp
399417
//===----------------------------------------------------------------------===//

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/BuiltinTypes.h"
1515
#include "mlir/IR/Dialect.h"
1616
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/SymbolTable.h"
1718
#include "mlir/Support/IndentedOstream.h"
1819
#include "mlir/Support/LLVM.h"
1920
#include "mlir/Target/Cpp/CppEmitter.h"
@@ -870,8 +871,9 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
870871
// needs to be printed after the closing brace.
871872
// When generating code for an emitc.for and emitc.verbatim op, printing a
872873
// trailing semicolon is handled within the printOperation function.
873-
bool trailingSemicolon = !isa<cf::CondBranchOp, emitc::ForOp, emitc::IfOp,
874-
emitc::LiteralOp, emitc::VerbatimOp>(op);
874+
bool trailingSemicolon =
875+
!isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
876+
emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
875877

876878
if (failed(emitter.emitOperation(
877879
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -953,6 +955,37 @@ static LogicalResult printOperation(CppEmitter &emitter,
953955
return success();
954956
}
955957

958+
static LogicalResult printOperation(CppEmitter &emitter,
959+
DeclareFuncOp declareFuncOp) {
960+
CppEmitter::Scope scope(emitter);
961+
raw_indented_ostream &os = emitter.ostream();
962+
963+
auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
964+
declareFuncOp, declareFuncOp.getSymNameAttr());
965+
966+
if (!functionOp)
967+
return failure();
968+
969+
if (functionOp.getSpecifiers()) {
970+
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
971+
os << cast<StringAttr>(specifier).str() << " ";
972+
}
973+
}
974+
975+
if (failed(emitter.emitTypes(functionOp.getLoc(),
976+
functionOp.getFunctionType().getResults())))
977+
return failure();
978+
os << " " << functionOp.getName();
979+
980+
os << "(";
981+
Operation *operation = functionOp.getOperation();
982+
if (failed(printFunctionArgs(emitter, operation, functionOp.getArguments())))
983+
return failure();
984+
os << ");";
985+
986+
return success();
987+
}
988+
956989
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
957990
: os(os), declareVariablesAtTop(declareVariablesAtTop) {
958991
valueInScopeCount.push(0);
@@ -1285,10 +1318,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
12851318
// EmitC ops.
12861319
.Case<emitc::AddOp, emitc::ApplyOp, emitc::AssignOp, emitc::CallOp,
12871320
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
1288-
emitc::ConstantOp, emitc::DivOp, emitc::ExpressionOp,
1289-
emitc::ForOp, emitc::FuncOp, emitc::IfOp, emitc::IncludeOp,
1290-
emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1291-
emitc::SubscriptOp, emitc::VariableOp, emitc::VerbatimOp>(
1321+
emitc::ConstantOp, emitc::DeclareFuncOp, emitc::DivOp,
1322+
emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp, emitc::IfOp,
1323+
emitc::IncludeOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1324+
emitc::SubOp, emitc::SubscriptOp, emitc::VariableOp,
1325+
emitc::VerbatimOp>(
12921326
[&](auto op) { return printOperation(*this, op); })
12931327
// Func ops.
12941328
.Case<func::CallOp, func::ConstantOp, func::FuncOp, func::ReturnOp>(

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,3 +329,13 @@ func.func @return_inside_func.func(%0: i32) -> (i32) {
329329

330330
// expected-error@+1 {{expected non-function type}}
331331
emitc.func @func_variadic(...)
332+
333+
// -----
334+
335+
// expected-error@+1 {{'emitc.declare_func' op 'bar' does not reference a valid function}}
336+
emitc.declare_func @bar
337+
338+
// -----
339+
340+
// expected-error@+1 {{'emitc.declare_func' op requires attribute 'sym_name'}}
341+
"emitc.declare_func"() : () -> ()

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ func.func @f(%arg0: i32, %f: !emitc.opaque<"int32_t">) {
1515
return
1616
}
1717

18+
emitc.declare_func @func
19+
1820
emitc.func @func(%arg0 : i32) {
1921
emitc.call_opaque "foo"(%arg0) : (i32) -> ()
2022
emitc.return
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2+
3+
// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]);
4+
emitc.declare_func @bar
5+
// CHECK: int32_t bar(int32_t [[V1:[^ ]*]]) {
6+
emitc.func @bar(%arg0: i32) -> i32 {
7+
emitc.return %arg0 : i32
8+
}
9+
10+
11+
// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]);
12+
emitc.declare_func @foo
13+
// CHECK: static inline int32_t foo(int32_t [[V1:[^ ]*]]) {
14+
emitc.func @foo(%arg0: i32) -> i32 attributes {specifiers = ["static","inline"]} {
15+
emitc.return %arg0 : i32
16+
}

0 commit comments

Comments
 (0)