Skip to content

Commit a30c983

Browse files
author
Simon Camphausen
committed
Emit emitc.apply "*" in a deferred way to allow assignment through pointers.
1 parent 6d085d7 commit a30c983

File tree

2 files changed

+76
-49
lines changed

2 files changed

+76
-49
lines changed

mlir/lib/Target/Cpp/TranslateToCpp.cpp

Lines changed: 75 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,9 @@ struct CppEmitter {
174174
/// Emit an expression as a C expression.
175175
LogicalResult emitExpression(ExpressionOp expressionOp);
176176

177+
/// Insert the expression representing the operation into the value cache.
178+
LogicalResult cacheDeferredOpResult(Operation *op);
179+
177180
/// Return the existing or a new name for a Value.
178181
StringRef getOrCreateName(Value val);
179182

@@ -273,6 +276,18 @@ struct CppEmitter {
273276
};
274277
} // namespace
275278

279+
/// Determine whether expression \p op should be emitted in a deferred way.
280+
static bool hasDeferredEmission(Operation *op) {
281+
if (isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::SubscriptOp>(
282+
op))
283+
return true;
284+
285+
if (auto applyOp = dyn_cast_or_null<emitc::ApplyOp>(op))
286+
return applyOp.getApplicableOperator() == "*";
287+
288+
return false;
289+
}
290+
276291
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
277292
/// as part of its user. This function recommends inlining of any expressions
278293
/// that can be inlined unless it is used by another expression, under the
@@ -295,10 +310,10 @@ static bool shouldBeInlined(ExpressionOp expressionOp) {
295310

296311
Operation *user = *result.getUsers().begin();
297312

298-
// Do not inline expressions used by subscript operations, since the
299-
// way the subscript operation translation is implemented requires that
300-
// variables be materialized.
301-
if (isa<emitc::SubscriptOp, emitc::GetGlobalOp>(user))
313+
// Do not inline expressions used by operations with deferred emission, since
314+
// the way their translation is implemented requires that variables be
315+
// materialized.
316+
if (hasDeferredEmission(user))
302317
return false;
303318

304319
// Do not inline expressions used by ops with the CExpression trait. If this
@@ -370,13 +385,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
370385
return emitter.emitOperand(assignOp.getValue());
371386
}
372387

373-
static LogicalResult printOperation(CppEmitter &emitter,
374-
emitc::GetGlobalOp op) {
375-
// Add name to cache so that `hasValueInScope` works.
376-
emitter.getOrCreateName(op.getResult());
377-
return success();
378-
}
379-
380388
static LogicalResult printOperation(CppEmitter &emitter,
381389
emitc::LValueLoadOp lValueLoadOp) {
382390
if (failed(emitter.emitAssignPrefix(*lValueLoadOp)))
@@ -385,13 +393,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
385393
return emitter.emitOperand(lValueLoadOp.getOperand());
386394
}
387395

388-
static LogicalResult printOperation(CppEmitter &emitter,
389-
emitc::SubscriptOp subscriptOp) {
390-
// Add name to cache so that `hasValueInScope` works.
391-
emitter.getOrCreateName(subscriptOp.getResult());
392-
return success();
393-
}
394-
395396
static LogicalResult printBinaryOperation(CppEmitter &emitter,
396397
Operation *operation,
397398
StringRef binaryOperator) {
@@ -629,9 +630,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
629630
if (t.getType().isIndex()) {
630631
int64_t idx = t.getInt();
631632
Value operand = op.getOperand(idx);
632-
auto literalDef =
633-
dyn_cast_if_present<LiteralOp>(operand.getDefiningOp());
634-
if (!literalDef && !emitter.hasValueInScope(operand))
633+
if (!emitter.hasValueInScope(operand))
635634
return op.emitOpError("operand ")
636635
<< idx << "'s value not defined in scope";
637636
os << emitter.getOrCreateName(operand);
@@ -668,6 +667,10 @@ static LogicalResult printOperation(CppEmitter &emitter,
668667
emitc::ApplyOp applyOp) {
669668
raw_ostream &os = emitter.ostream();
670669
Operation &op = *applyOp.getOperation();
670+
StringRef applicableOp = applyOp.getApplicableOperator();
671+
672+
if (applicableOp == "*")
673+
return emitter.cacheDeferredOpResult(applyOp);
671674

672675
if (failed(emitter.emitAssignPrefix(op)))
673676
return failure();
@@ -956,8 +959,7 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
956959
// regions.
957960
WalkResult result =
958961
functionOp->walk<WalkOrder::PreOrder>([&](Operation *op) -> WalkResult {
959-
if (isa<emitc::LiteralOp>(op) ||
960-
isa<emitc::ExpressionOp>(op->getParentOp()) ||
962+
if (isa<emitc::ExpressionOp>(op->getParentOp()) ||
961963
(isa<emitc::ExpressionOp>(op) &&
962964
shouldBeInlined(cast<emitc::ExpressionOp>(op))))
963965
return WalkResult::skip();
@@ -1009,7 +1011,8 @@ static LogicalResult printFunctionBody(CppEmitter &emitter,
10091011
// trailing semicolon is handled within the printOperation function.
10101012
bool trailingSemicolon =
10111013
!isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::ForOp,
1012-
emitc::IfOp, emitc::LiteralOp, emitc::VerbatimOp>(op);
1014+
emitc::IfOp, emitc::VerbatimOp>(op) ||
1015+
hasDeferredEmission(&op);
10131016

10141017
if (failed(emitter.emitOperation(
10151018
op, /*trailingSemicolon=*/trailingSemicolon)))
@@ -1142,20 +1145,48 @@ std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
11421145
return out;
11431146
}
11441147

1148+
LogicalResult CppEmitter::cacheDeferredOpResult(Operation *op) {
1149+
if (op->getNumResults() != 1)
1150+
return op->emitError("Adding deferred ops into value cahce only works for "
1151+
"single results operations, got ")
1152+
<< op->getNumResults() << " results";
1153+
1154+
Value result = op->getResult(0);
1155+
if (valueMapper.count(result))
1156+
return success();
1157+
1158+
if (auto applyOp = dyn_cast<emitc::ApplyOp>(op)) {
1159+
assert(applyOp.getApplicableOperator() == "*" && "expected derefernce");
1160+
valueMapper.insert(result, std::string("*") +
1161+
getOrCreateName(applyOp.getOperand()).str());
1162+
return success();
1163+
}
1164+
1165+
if (auto getGlobal = dyn_cast<emitc::GetGlobalOp>(op)) {
1166+
valueMapper.insert(result, getGlobal.getName().str());
1167+
return success();
1168+
}
1169+
1170+
if (auto literal = dyn_cast<emitc::LiteralOp>(op)) {
1171+
valueMapper.insert(result, literal.getValue().str());
1172+
return success();
1173+
}
1174+
1175+
if (auto subscript = dyn_cast<emitc::SubscriptOp>(op)) {
1176+
valueMapper.insert(result, getSubscriptName(subscript));
1177+
return success();
1178+
}
1179+
1180+
return op->emitError("cacheDeferredOpResult not implemented");
1181+
}
1182+
11451183
/// Return the existing or a new name for a Value.
11461184
StringRef CppEmitter::getOrCreateName(Value val) {
1147-
if (auto literal = dyn_cast_if_present<emitc::LiteralOp>(val.getDefiningOp()))
1148-
return literal.getValue();
11491185
if (!valueMapper.count(val)) {
1150-
if (auto subscript =
1151-
dyn_cast_if_present<emitc::SubscriptOp>(val.getDefiningOp())) {
1152-
valueMapper.insert(val, getSubscriptName(subscript));
1153-
} else if (auto getGlobal = dyn_cast_if_present<emitc::GetGlobalOp>(
1154-
val.getDefiningOp())) {
1155-
valueMapper.insert(val, getGlobal.getName().str());
1156-
} else {
1157-
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
1158-
}
1186+
assert(!hasDeferredEmission(val.getDefiningOp()) &&
1187+
"cacheDeferredOpResult should have been called on this value, "
1188+
"update the emitOperation function.");
1189+
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
11591190
}
11601191
return *valueMapper.begin(val);
11611192
}
@@ -1349,9 +1380,6 @@ LogicalResult CppEmitter::emitOperand(Value value) {
13491380
if (expressionOp && shouldBeInlined(expressionOp))
13501381
return emitExpression(expressionOp);
13511382

1352-
auto literalOp = dyn_cast_if_present<LiteralOp>(value.getDefiningOp());
1353-
if (!literalOp && !hasValueInScope(value))
1354-
return failure();
13551383
os << getOrCreateName(value);
13561384
return success();
13571385
}
@@ -1407,7 +1435,7 @@ LogicalResult CppEmitter::emitVariableAssignment(OpResult result) {
14071435

14081436
LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
14091437
bool trailingSemicolon) {
1410-
if (isa<emitc::SubscriptOp, emitc::GetGlobalOp>(result.getDefiningOp()))
1438+
if (hasDeferredEmission(result.getDefiningOp()))
14111439
return success();
14121440
if (hasValueInScope(result)) {
14131441
return result.getDefiningOp()->emitError(
@@ -1506,25 +1534,25 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
15061534
emitc::CallOpaqueOp, emitc::CastOp, emitc::CmpOp,
15071535
emitc::ConditionalOp, emitc::ConstantOp, emitc::DeclareFuncOp,
15081536
emitc::DivOp, emitc::ExpressionOp, emitc::ForOp, emitc::FuncOp,
1509-
emitc::GlobalOp, emitc::GetGlobalOp, emitc::IfOp,
1510-
emitc::IncludeOp, emitc::LogicalAndOp, emitc::LogicalNotOp,
1511-
emitc::LogicalOrOp, emitc::LValueLoadOp, emitc::MulOp,
1512-
emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SubscriptOp,
1513-
emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
1514-
emitc::VerbatimOp>(
1537+
emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp,
1538+
emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
1539+
emitc::LValueLoadOp, emitc::MulOp, emitc::RemOp,
1540+
emitc::ReturnOp, emitc::SubOp, emitc::UnaryMinusOp,
1541+
emitc::UnaryPlusOp, emitc::VariableOp, emitc::VerbatimOp>(
15151542
[&](auto op) { return printOperation(*this, op); })
15161543
// Func ops.
15171544
.Case<func::CallOp, func::FuncOp, func::ReturnOp>(
15181545
[&](auto op) { return printOperation(*this, op); })
1519-
.Case<emitc::LiteralOp>([&](auto op) { return success(); })
1546+
.Case<emitc::GetGlobalOp, emitc::LiteralOp, emitc::SubscriptOp>(
1547+
[&](Operation *op) { return cacheDeferredOpResult(op); })
15201548
.Default([&](Operation *) {
15211549
return op.emitOpError("unable to find printer for op");
15221550
});
15231551

15241552
if (failed(status))
15251553
return failure();
15261554

1527-
if (isa<emitc::LiteralOp, emitc::SubscriptOp, emitc::GetGlobalOp>(op))
1555+
if (hasDeferredEmission(&op))
15281556
return success();
15291557

15301558
if (getEmittedExpression() ||

mlir/test/Target/Cpp/common-cpp.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,8 @@ func.func @apply(%arg0: !emitc.lvalue<i32>) -> !emitc.ptr<i32> {
9191
%2 = "emitc.variable"() {value = #emitc.opaque<"">} : () -> !emitc.lvalue<i32>
9292
// CHECK-NEXT: int32_t [[V4:[^ ]*]] = *[[V2]];
9393
%1 = emitc.apply "*"(%0) : (!emitc.ptr<i32>) -> !emitc.lvalue<i32>
94-
// CHECK-NEXT: int32_t [[V5:[^ ]*]] = [[V4]];
9594
%3 = emitc.lvalue_load %1 : !emitc.lvalue<i32>
96-
// CHECK-NEXT: [[V3]] = [[V5]];
95+
// CHECK-NEXT: [[V3]] = [[V4]];
9796
emitc.assign %3 : i32 to %2 : !emitc.lvalue<i32>
9897
// CHECK-NEXT: return [[V2]];
9998
return %0 : !emitc.ptr<i32>

0 commit comments

Comments
 (0)