Skip to content

Commit d521324

Browse files
Simon Camphausenmarbre
andauthored
[mlir][EmitC] Unify handling of operations which are emitted in a deferred way (llvm#97804)
Several operations from the EmitC dialect don't produce output directly during emission, but rather when being used as an operand. These changes unify the handling of such operations and fix a bug in the emission of global ops. Co-authored-by: Marius Brehler <[email protected]>
1 parent b841e2e commit d521324

File tree

5 files changed

+122
-64
lines changed

5 files changed

+122
-64
lines changed

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,10 @@ LogicalResult emitc::AssignOp::verify() {
213213
Value variable = getVar();
214214
Operation *variableDef = variable.getDefiningOp();
215215
if (!variableDef ||
216-
!llvm::isa<emitc::VariableOp, emitc::SubscriptOp>(variableDef))
216+
!llvm::isa<emitc::GetGlobalOp, emitc::SubscriptOp, emitc::VariableOp>(
217+
variableDef))
217218
return emitOpError() << "requires first operand (" << variable
218-
<< ") to be a Variable or subscript";
219+
<< ") to be a get_global, subscript or variable";
219220

220221
Value value = getValue();
221222
if (variable.getType() != value.getType())

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 43 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ struct CppEmitter {
174174
/// Emit an expression as a C expression.
175175
LogicalResult emitExpression(ExpressionOp expressionOp);
176176

177+
/// Insert the expression representing the operation into the value cache.
178+
void cacheDeferredOpResult(Value value, StringRef str);
179+
177180
/// Return the existing or a new name for a Value.
178181
StringRef getOrCreateName(Value val);
179182

@@ -273,6 +276,12 @@ struct CppEmitter {
273276
};
274277
} // namespace
275278

279+
/// Determine whether expression \p op should be emitted in a deferred way.
280+
static bool hasDeferredEmission(Operation *op) {
281+
return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp,
282+
emitc::SubscriptOp>(op);
283+
}
284+
276285
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
277286
/// as part of its user. This function recommends inlining of any expressions
278287
/// that can be inlined unless it is used by another expression, under the
@@ -295,10 +304,9 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
295304

296305
Operation *user = *result.getUsers().begin();
297306

298-
// Do not inline expressions used by subscript operations, since the
299-
// way the subscript operation translation is implemented requires that
300-
// variables be materialized.
301-
if (isa<emitc::SubscriptOp>(user))
307+
// Do not inline expressions used by operations with deferred emission, since
308+
// their translation requires the materialization of variables.
309+
if (hasDeferredEmission(user))
302310
return false;
303311

304312
// Do not inline expressions used by ops with the CExpression trait. If this
@@ -370,20 +378,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
370378
return emitter.emitOperand(assignOp.getValue());
371379
}
372380

373-
static LogicalResult printOperation(CppEmitter &emitter,
374-
emitc::GetGlobalOp op) {
375-
// Add name to cache so that `hasValueInScope` works.
376-
emitter.getOrCreateName(op.getResult());
377-
return success();
378-
}
379-
380-
static LogicalResult printOperation(CppEmitter &emitter,
381-
emitc::SubscriptOp subscriptOp) {
382-
// Add name to cache so that `hasValueInScope` works.
383-
emitter.getOrCreateName(subscriptOp.getResult());
384-
return success();
385-
}
386-
387381
static LogicalResult printBinaryOperation(CppEmitter &emitter,
388382
Operation *operation,
389383
StringRef binaryOperator) {
@@ -621,9 +615,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
621615
if (t.getType().isIndex()) {
622616
int64_t idx = t.getInt();
623617
Value operand = op.getOperand(idx);
624-
auto literalDef =
625-
dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
626-
if (!literalDef && !emitter.hasValueInScope(operand))
618+
if (!emitter.hasValueInScope(operand))
627619
return op.emitOpError("operand ")
628620
<< idx << "'s value not defined in scope";
629621
os << emitter.getOrCreateName(operand);
@@ -948,8 +940,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
948940
// regions.
949941
WalkResult result =
950942
functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
951-
if (isa<emitc::LiteralOp>(op) ||
952-
isa<emitc::ExpressionOp>(op->getParentOp()) ||
943+
if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
953944
(isa<emitc::ExpressionOp>(op) &&
954945
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
955946
return WalkResult::skip();
@@ -1001,7 +992,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
1001992
// trailing semicolon is handled within the printOperation function.
1002993
bool trailingSemicolon =
1003994
!isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
1004-
emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
995+
emitc::IfOp, emitc::VerbatimOp>(op);
1005996

1006997
if (failed(emitter.emitOperation(
1007998
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1134,20 +1125,18 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
11341125
return out;
11351126
}
11361127

1128+
void CppEmitter::cacheDeferredOpResult(Value value, StringRef str) {
1129+
if (!valueMapper.count(value))
1130+
valueMapper.insert(value, str.str());
1131+
}
1132+
11371133
/// Return the existing or a new name for a Value.
11381134
StringRef CppEmitter::getOrCreateName(Value val) {
1139-
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
1140-
return literal.getValue();
11411135
if (!valueMapper.count(val)) {
1142-
if (auto subscript =
1143-
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
1144-
valueMapper.insert(val, getSubscriptName(subscript));
1145-
} else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
1146-
val.getDefiningOp())) {
1147-
valueMapper.insert(val, getGlobal.getName().str());
1148-
} else {
1149-
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1150-
}
1136+
assert(!hasDeferredEmission(val.getDefiningOp()) &&
1137+
"cacheDeferredOpResult should have been called on this value, "
1138+
"update the emitOperation function.");
1139+
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
11511140
}
11521141
return *valueMapper.begin(val);
11531142
}
@@ -1341,9 +1330,6 @@ LogicalResult CppEmitter::emitOperand(Value value) {
13411330
if (expressionOp && shouldBeInlined(expressionOp))
13421331
return emitExpression(expressionOp);
13431332

1344-
auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
1345-
if (!literalOp && !hasValueInScope(value))
1346-
return failure();
13471333
os << getOrCreateName(value);
13481334
return success();
13491335
}
@@ -1399,7 +1385,7 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
13991385

14001386
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
14011387
bool trailingSemicolon) {
1402-
if (isa<emitc::SubscriptOp>(result.getDefiningOp()))
1388+
if (hasDeferredEmission(result.getDefiningOp()))
14031389
return success();
14041390
if (hasValueInScope(result)) {
14051391
return result.getDefiningOp()->emitError(
@@ -1498,24 +1484,35 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
14981484
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
14991485
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
15001486
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1501-
emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
1502-
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1503-
emitc::LogicalOrOp, emitc::MulOp, emitc::RemOp, emitc::ReturnOp,
1504-
emitc::SubOp, emitc::SubscriptOp, emitc::UnaryMinusOp,
1505-
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
1487+
emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
1488+
emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
1489+
emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
1490+
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1491+
emitc::VerbatimOp>(
15061492
[&](auto op) { return printOperation(*this, op); })
15071493
// Func ops.
15081494
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
15091495
[&](auto op) { return printOperation(*this, op); })
1510-
.Case<emitc::LiteralOp>([&](auto op) { return success(); })
1496+
.Case<emitc::GetGlobalOp>([&](auto op) {
1497+
cacheDeferredOpResult(op.getResult(), op.getName());
1498+
return success();
1499+
})
1500+
.Case<emitc::LiteralOp>([&](auto op) {
1501+
cacheDeferredOpResult(op.getResult(), op.getValue());
1502+
return success();
1503+
})
1504+
.Case<emitc::SubscriptOp>([&](auto op) {
1505+
cacheDeferredOpResult(op.getResult(), getSubscriptName(op));
1506+
return success();
1507+
})
15111508
.Default([&](Operation *) {
15121509
return op.emitOpError("unable to find printer for op");
15131510
});
15141511

15151512
if (failed(status))
15161513
return failure();
15171514

1518-
if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
1515+
if (hasDeferredEmission(&op))
15191516
return success();
15201517

15211518
if (getEmittedExpression() ||

mlir/test/Dialect/EmitC/invalid_ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ func.func @test_misplaced_yield() {
235235
// -----
236236

237237
func.func @test_assign_to_non_variable(%arg1: f32, %arg2: f32) {
238-
// expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a Variable or subscript}}
238+
// expected-error @+1 {{'emitc.assign' op requires first operand (<block argument> of type 'f32' at index: 1) to be a get_global, subscript or variable}}
239239
emitc.assign %arg1 : f32 to %arg2 : f32
240240
return
241241
}

mlir/test/Dialect/EmitC/ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,9 @@ func.func @use_global(%i: index) -> f32 {
248248
%1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
249249
return %1 : f32
250250
}
251+
252+
func.func @assign_global(%arg0 : i32) {
253+
%0 = emitc.get_global @myglobal_int : i32
254+
emitc.assign %arg0 : i32 to %0 : i32
255+
return
256+
}

mlir/test/Target/Cpp/global.mlir

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,92 @@
1-
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s
2-
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s
1+
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT
2+
// RUN: mlir-translate -mlir-to-cpp -declare-variables-at-top %s | FileCheck %s -check-prefix=CPP-DECLTOP
33

44
emitc.global extern @decl : i8
5-
// CHECK: extern int8_t decl;
5+
// CPP-DEFAULT: extern int8_t decl;
6+
// CPP-DECLTOP: extern int8_t decl;
67

78
emitc.global @uninit : i32
8-
// CHECK: int32_t uninit;
9+
// CPP-DEFAULT: int32_t uninit;
10+
// CPP-DECLTOP: int32_t uninit;
911

1012
emitc.global @myglobal_int : i32 = 4
11-
// CHECK: int32_t myglobal_int = 4;
13+
// CPP-DEFAULT: int32_t myglobal_int = 4;
14+
// CPP-DECLTOP: int32_t myglobal_int = 4;
1215

1316
emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00>
14-
// CHECK: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
17+
// CPP-DEFAULT: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
18+
// CPP-DECLTOP: float myglobal[2] = {4.000000000e+00f, 4.000000000e+00f};
1519

1620
emitc.global const @myconstant : !emitc.array<2xi16> = dense<2>
17-
// CHECK: const int16_t myconstant[2] = {2, 2};
21+
// CPP-DEFAULT: const int16_t myconstant[2] = {2, 2};
22+
// CPP-DECLTOP: const int16_t myconstant[2] = {2, 2};
1823

1924
emitc.global extern const @extern_constant : !emitc.array<2xi16>
20-
// CHECK: extern const int16_t extern_constant[2];
25+
// CPP-DEFAULT: extern const int16_t extern_constant[2];
26+
// CPP-DECLTOP: extern const int16_t extern_constant[2];
2127

2228
emitc.global static @static_var : f32
23-
// CHECK: static float static_var;
29+
// CPP-DEFAULT: static float static_var;
30+
// CPP-DECLTOP: static float static_var;
2431

2532
emitc.global static @static_const : f32 = 3.0
26-
// CHECK: static float static_const = 3.000000000e+00f;
33+
// CPP-DEFAULT: static float static_const = 3.000000000e+00f;
34+
// CPP-DECLTOP: static float static_const = 3.000000000e+00f;
2735

2836
emitc.global @opaque_init : !emitc.opaque<"char"> = #emitc.opaque<"CHAR_MIN">
29-
// CHECK: char opaque_init = CHAR_MIN;
37+
// CPP-DEFAULT: char opaque_init = CHAR_MIN;
38+
// CPP-DECLTOP: char opaque_init = CHAR_MIN;
3039

31-
func.func @use_global(%i: index) -> f32 {
40+
func.func @use_global_scalar_read() -> i32 {
41+
%0 = emitc.get_global @myglobal_int : i32
42+
return %0 : i32
43+
}
44+
// CPP-DEFAULT-LABEL: int32_t use_global_scalar_read()
45+
// CPP-DEFAULT-NEXT: return myglobal_int;
46+
47+
// CPP-DECLTOP-LABEL: int32_t use_global_scalar_read()
48+
// CPP-DECLTOP-NEXT: return myglobal_int;
49+
50+
func.func @use_global_scalar_write(%arg0 : i32) {
51+
%0 = emitc.get_global @myglobal_int : i32
52+
emitc.assign %arg0 : i32 to %0 : i32
53+
return
54+
}
55+
// CPP-DEFAULT-LABEL: void use_global_scalar_write
56+
// CPP-DEFAULT-SAME: (int32_t [[V1:.*]])
57+
// CPP-DEFAULT-NEXT: myglobal_int = [[V1]];
58+
// CPP-DEFAULT-NEXT: return;
59+
60+
// CPP-DECLTOP-LABEL: void use_global_scalar_write
61+
// CPP-DECLTOP-SAME: (int32_t [[V1:.*]])
62+
// CPP-DECLTOP-NEXT: myglobal_int = [[V1]];
63+
// CPP-DECLTOP-NEXT: return;
64+
65+
func.func @use_global_array_read(%i: index) -> f32 {
3266
%0 = emitc.get_global @myglobal : !emitc.array<2xf32>
3367
%1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
3468
return %1 : f32
35-
// CHECK-LABEL: use_global
36-
// CHECK-SAME: (size_t [[V1:.*]])
37-
// CHECK: return myglobal[[[V1]]];
3869
}
70+
// CPP-DEFAULT-LABEL: float use_global_array_read
71+
// CPP-DEFAULT-SAME: (size_t [[V1:.*]])
72+
// CPP-DEFAULT-NEXT: return myglobal[[[V1]]];
73+
74+
// CPP-DECLTOP-LABEL: float use_global_array_read
75+
// CPP-DECLTOP-SAME: (size_t [[V1:.*]])
76+
// CPP-DECLTOP-NEXT: return myglobal[[[V1]]];
77+
78+
func.func @use_global_array_write(%i: index, %val : f32) {
79+
%0 = emitc.get_global @myglobal : !emitc.array<2xf32>
80+
%1 = emitc.subscript %0[%i] : (!emitc.array<2xf32>, index) -> f32
81+
emitc.assign %val : f32 to %1 : f32
82+
return
83+
}
84+
// CPP-DEFAULT-LABEL: void use_global_array_write
85+
// CPP-DEFAULT-SAME: (size_t [[V1:.*]], float [[V2:.*]])
86+
// CPP-DEFAULT-NEXT: myglobal[[[V1]]] = [[V2]];
87+
// CPP-DEFAULT-NEXT: return;
88+
89+
// CPP-DECLTOP-LABEL: void use_global_array_write
90+
// CPP-DECLTOP-SAME: (size_t [[V1:.*]], float [[V2:.*]])
91+
// CPP-DECLTOP-NEXT: myglobal[[[V1]]] = [[V2]];
92+
// CPP-DECLTOP-NEXT: return;

0 commit comments

Comments
 (0)