Skip to content

[clang][Interp] Simplify and fix variable scope handling #101788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 15 additions & 48 deletions clang/lib/AST/Interp/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2226,7 +2226,7 @@ bool Compiler<Emitter>::VisitExprWithCleanups(const ExprWithCleanups *E) {

assert(E->getNumObjects() == 0 && "TODO: Implement cleanups");

return this->delegate(SubExpr) && ES.destroyLocals();
return this->delegate(SubExpr) && ES.destroyLocals(E);
}

template <class Emitter>
Expand Down Expand Up @@ -2537,13 +2537,8 @@ bool Compiler<Emitter>::VisitCXXConstructExpr(const CXXConstructExpr *E) {
return false;
}

// Immediately call the destructor if we have to.
if (DiscardResult) {
if (!this->emitRecordDestruction(getRecord(E->getType())))
return false;
if (!this->emitPopPtr(E))
return false;
}
if (DiscardResult)
return this->emitPopPtr(E);
return true;
}

Expand Down Expand Up @@ -4222,22 +4217,6 @@ template <class Emitter> bool Compiler<Emitter>::visitStmt(const Stmt *S) {
}
}

/// Visits the given statment without creating a variable
/// scope for it in case it is a compound statement.
template <class Emitter> bool Compiler<Emitter>::visitLoopBody(const Stmt *S) {
if (isa<NullStmt>(S))
return true;

if (const auto *CS = dyn_cast<CompoundStmt>(S)) {
for (const auto *InnerStmt : CS->body())
if (!visitStmt(InnerStmt))
return false;
return true;
}

return this->visitStmt(S);
}

template <class Emitter>
bool Compiler<Emitter>::visitCompoundStmt(const CompoundStmt *S) {
BlockScope<Emitter> Scope(this);
Expand Down Expand Up @@ -4300,8 +4279,6 @@ bool Compiler<Emitter>::visitReturnStmt(const ReturnStmt *RS) {
}

template <class Emitter> bool Compiler<Emitter>::visitIfStmt(const IfStmt *IS) {
BlockScope<Emitter> IfScope(this);

if (IS->isNonNegatedConsteval())
return visitStmt(IS->getThen());
if (IS->isNegatedConsteval())
Expand Down Expand Up @@ -4340,7 +4317,7 @@ template <class Emitter> bool Compiler<Emitter>::visitIfStmt(const IfStmt *IS) {
this->emitLabel(LabelEnd);
}

return IfScope.destroyLocals();
return true;
}

template <class Emitter>
Expand All @@ -4364,12 +4341,8 @@ bool Compiler<Emitter>::visitWhileStmt(const WhileStmt *S) {
if (!this->jumpFalse(EndLabel))
return false;

LocalScope<Emitter> Scope(this);
{
DestructorScope<Emitter> DS(Scope);
if (!this->visitLoopBody(Body))
return false;
}
if (!this->visitStmt(Body))
return false;

if (!this->jump(CondLabel))
return false;
Expand All @@ -4387,14 +4360,11 @@ template <class Emitter> bool Compiler<Emitter>::visitDoStmt(const DoStmt *S) {
LabelTy EndLabel = this->getLabel();
LabelTy CondLabel = this->getLabel();
LoopScope<Emitter> LS(this, EndLabel, CondLabel);
LocalScope<Emitter> Scope(this);

this->fallthrough(StartLabel);
this->emitLabel(StartLabel);
{
DestructorScope<Emitter> DS(Scope);

if (!this->visitLoopBody(Body))
if (!this->visitStmt(Body))
return false;
this->fallthrough(CondLabel);
this->emitLabel(CondLabel);
Expand All @@ -4421,10 +4391,10 @@ bool Compiler<Emitter>::visitForStmt(const ForStmt *S) {
LabelTy CondLabel = this->getLabel();
LabelTy IncLabel = this->getLabel();
LoopScope<Emitter> LS(this, EndLabel, IncLabel);
LocalScope<Emitter> Scope(this);

if (Init && !this->visitStmt(Init))
return false;

this->fallthrough(CondLabel);
this->emitLabel(CondLabel);

Expand All @@ -4440,10 +4410,9 @@ bool Compiler<Emitter>::visitForStmt(const ForStmt *S) {
}

{
DestructorScope<Emitter> DS(Scope);

if (Body && !this->visitLoopBody(Body))
if (Body && !this->visitStmt(Body))
return false;

this->fallthrough(IncLabel);
this->emitLabel(IncLabel);
if (Inc && !this->discard(Inc))
Expand Down Expand Up @@ -4495,13 +4464,11 @@ bool Compiler<Emitter>::visitCXXForRangeStmt(const CXXForRangeStmt *S) {
return false;

// Body.
LocalScope<Emitter> Scope(this);
{
DestructorScope<Emitter> DS(Scope);

if (!this->visitLoopBody(Body))
if (!this->visitStmt(Body))
return false;
this->fallthrough(IncLabel);

this->fallthrough(IncLabel);
this->emitLabel(IncLabel);
if (!this->discard(Inc))
return false;
Expand All @@ -4520,7 +4487,7 @@ bool Compiler<Emitter>::visitBreakStmt(const BreakStmt *S) {
if (!BreakLabel)
return false;

this->VarScope->emitDestructors();
this->emitCleanup();
return this->jump(*BreakLabel);
}

Expand All @@ -4529,7 +4496,7 @@ bool Compiler<Emitter>::visitContinueStmt(const ContinueStmt *S) {
if (!ContinueLabel)
return false;

this->VarScope->emitDestructors();
this->emitCleanup();
return this->jump(*ContinueLabel);
}

Expand Down
44 changes: 20 additions & 24 deletions clang/lib/AST/Interp/Compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ class Compiler : public ConstStmtVisitor<Compiler<Emitter>, bool>,

// Statements.
bool visitCompoundStmt(const CompoundStmt *S);
bool visitLoopBody(const Stmt *S);
bool visitDeclStmt(const DeclStmt *DS);
bool visitReturnStmt(const ReturnStmt *RS);
bool visitIfStmt(const IfStmt *IS);
Expand Down Expand Up @@ -452,11 +451,15 @@ template <class Emitter> class VariableScope {
}

// Use the parent scope.
addExtended(Local);
if (this->Parent)
this->Parent->addLocal(Local);
else
this->addLocal(Local);
}

virtual void emitDestruction() {}
virtual bool emitDestructors() { return true; }
virtual bool emitDestructors(const Expr *E = nullptr) { return true; }
virtual bool destroyLocals(const Expr *E = nullptr) { return true; }
VariableScope *getParent() const { return Parent; }

protected:
Expand All @@ -483,16 +486,21 @@ template <class Emitter> class LocalScope : public VariableScope<Emitter> {
}

/// Overriden to support explicit destruction.
void emitDestruction() override { destroyLocals(); }
void emitDestruction() override {
if (!Idx)
return;

this->emitDestructors();
this->Ctx->emitDestroy(*Idx, SourceInfo{});
}

/// Explicit destruction of local variables.
bool destroyLocals() {
bool destroyLocals(const Expr *E = nullptr) override {
if (!Idx)
return true;

bool Success = this->emitDestructors();
this->Ctx->emitDestroy(*Idx, SourceInfo{});
removeStoredOpaqueValues();
bool Success = this->emitDestructors(E);
this->Ctx->emitDestroy(*Idx, E);
this->Idx = std::nullopt;
return Success;
}
Expand All @@ -501,25 +509,26 @@ template <class Emitter> class LocalScope : public VariableScope<Emitter> {
if (!Idx) {
Idx = this->Ctx->Descriptors.size();
this->Ctx->Descriptors.emplace_back();
this->Ctx->emitInitScope(*Idx, {});
}

this->Ctx->Descriptors[*Idx].emplace_back(Local);
}

bool emitDestructors() override {
bool emitDestructors(const Expr *E = nullptr) override {
if (!Idx)
return true;
// Emit destructor calls for local variables of record
// type with a destructor.
for (Scope::Local &Local : this->Ctx->Descriptors[*Idx]) {
if (!Local.Desc->isPrimitive() && !Local.Desc->isPrimitiveArray()) {
if (!this->Ctx->emitGetPtrLocal(Local.Offset, SourceInfo{}))
if (!this->Ctx->emitGetPtrLocal(Local.Offset, E))
return false;

if (!this->Ctx->emitDestruction(Local.Desc))
return false;

if (!this->Ctx->emitPopPtr(SourceInfo{}))
if (!this->Ctx->emitPopPtr(E))
return false;
removeIfStoredOpaqueValue(Local);
}
Expand Down Expand Up @@ -549,19 +558,6 @@ template <class Emitter> class LocalScope : public VariableScope<Emitter> {
std::optional<unsigned> Idx;
};

/// Emits the destructors of the variables of \param OtherScope
/// when this scope is destroyed. Does not create a Scope in the bytecode at
/// all, this is just a RAII object to emit destructors.
template <class Emitter> class DestructorScope final {
public:
DestructorScope(LocalScope<Emitter> &OtherScope) : OtherScope(OtherScope) {}

~DestructorScope() { OtherScope.emitDestructors(); }

private:
LocalScope<Emitter> &OtherScope;
};

/// Scope for storage declared in a compound statement.
template <class Emitter> class BlockScope final : public LocalScope<Emitter> {
public:
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/AST/Interp/Interp.h
Original file line number Diff line number Diff line change
Expand Up @@ -2012,6 +2012,11 @@ inline bool Destroy(InterpState &S, CodePtr OpPC, uint32_t I) {
return true;
}

inline bool InitScope(InterpState &S, CodePtr OpPC, uint32_t I) {
S.Current->initScope(I);
return true;
}

//===----------------------------------------------------------------------===//
// Cast, CastFP
//===----------------------------------------------------------------------===//
Expand Down
14 changes: 11 additions & 3 deletions clang/lib/AST/Interp/InterpFrame.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ InterpFrame::InterpFrame(InterpState &S, const Function *Func,
Locals = std::make_unique<char[]>(FrameSize);
for (auto &Scope : Func->scopes()) {
for (auto &Local : Scope.locals()) {
Block *B =
new (localBlock(Local.Offset)) Block(S.Ctx.getEvalID(), Local.Desc);
B->invokeCtor();
new (localBlock(Local.Offset)) Block(S.Ctx.getEvalID(), Local.Desc);
// Note that we are NOT calling invokeCtor() here, since that is done
// via the InitScope op.
new (localInlineDesc(Local.Offset)) InlineDescriptor(Local.Desc);
}
}
Expand Down Expand Up @@ -83,6 +83,14 @@ InterpFrame::~InterpFrame() {
}
}

void InterpFrame::initScope(unsigned Idx) {
if (!Func)
return;
for (auto &Local : Func->getScope(Idx).locals()) {
localBlock(Local.Offset)->invokeCtor();
}
}

void InterpFrame::destroy(unsigned Idx) {
for (auto &Local : Func->getScope(Idx).locals()) {
S.deallocate(localBlock(Local.Offset));
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Interp/InterpFrame.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class InterpFrame final : public Frame {

/// Invokes the destructors for a scope.
void destroy(unsigned Idx);
void initScope(unsigned Idx);

/// Pops the arguments off the stack.
void popArgs();
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/AST/Interp/Opcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,9 @@ def Destroy : Opcode {
let Args = [ArgUint32];
let HasCustomEval = 1;
}
def InitScope : Opcode {
let Args = [ArgUint32];
}

//===----------------------------------------------------------------------===//
// Constants
Expand Down
18 changes: 18 additions & 0 deletions clang/test/AST/Interp/if.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,21 @@ constexpr char g(char const (&x)[2]) {
;
}
static_assert(g("x") == 'x');

namespace IfScope {
struct Inc {
int &a;
constexpr Inc(int &a) : a(a) {}
constexpr ~Inc() { ++a; }
};

constexpr int foo() {
int a= 0;
int b = 12;
if (Inc{a}; true) {
b += a;
}
return b;
}
static_assert(foo() == 13, "");
}
28 changes: 28 additions & 0 deletions clang/test/AST/Interp/loops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,31 @@ namespace RangeForLoop {
// ref-note {{semicolon on a separate line}}
}
}

namespace Scopes {
constexpr int foo() {
int n = 0;
{
int m = 12;
for (int i = 0;i < 10;++i) {

{
int a = 10;
{
int b = 20;
{
int c = 30;
continue;
}
}
}
}
++m;
n = m;
}

++n;
return n;
}
static_assert(foo() == 14, "");
}
Loading