Skip to content

[astgen] Add ReturnStmt visitor. #61809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions include/swift/AST/CASTBridging.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ void *SwiftTopLevelCodeDecl_createStmt(void *ctx, void *DC, void *startLoc,
void *SwiftTopLevelCodeDecl_createExpr(void *ctx, void *DC, void *startLoc,
void *element, void *endLoc);

void *ReturnStmt_create(void *ctx, void *loc, void *_Nullable expr);


void *SwiftSequenceExpr_create(void *ctx, BridgedArrayRef exprs);

void *SwiftTupleExpr_create(void *ctx, void *lparen, BridgedArrayRef subs,
Expand All @@ -128,8 +131,13 @@ void *SwiftVarDecl_create(void *ctx, BridgedIdentifier _Nullable name,
void *IfStmt_create(void *ctx, void *ifLoc, void *cond, void *_Nullable then, void *_Nullable elseLoc,
void *_Nullable elseStmt);

void *BraceStmt_createExpr(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
void *BraceStmt_createStmt(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
struct ASTNodeBridged {
void *ptr;
_Bool isExpr; // Must be expr or stmt.
};

void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc);

void *BridgedSourceLoc_advanced(void *loc, long len);

Expand Down
18 changes: 11 additions & 7 deletions lib/AST/CASTBridging.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,21 @@ void *IfStmt_create(void *ctx, void *ifLoc, void *cond, void *_Nullable then, vo
getSourceLocFromPointer(elseLoc), (Stmt *)elseStmt, None, Context);
}

void *BraceStmt_createExpr(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
void *ReturnStmt_create(void *ctx, void *loc, void *_Nullable expr) {
ASTContext &Context = *static_cast<ASTContext *>(ctx);
return BraceStmt::create(Context, getSourceLocFromPointer(lbloc),
getArrayRef<ASTNode>(elements),
getSourceLocFromPointer(rbloc));
return new (Context) ReturnStmt(getSourceLocFromPointer(loc), (Expr *)expr);
}

void *BraceStmt_createStmt(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
void *BraceStmt_create(void *ctx, void *lbloc, BridgedArrayRef elements, void *rbloc) {
llvm::SmallVector<ASTNode, 6> nodes;
for (auto stmt : getArrayRef<Stmt *>(elements)) {
nodes.push_back(stmt);
for (auto node : getArrayRef<ASTNodeBridged>(elements)) {
if (node.isExpr) {
auto expr = (Expr *)node.ptr;
nodes.push_back(expr);
} else {
auto stmt = (Stmt *)node.ptr;
nodes.push_back(stmt);
}
}

ASTContext &Context = *static_cast<ASTContext *>(ctx);
Expand Down
53 changes: 44 additions & 9 deletions lib/ASTGen/Sources/ASTGen/ASTGen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,37 @@ extension UnsafePointer {
}
}

enum ASTNode {
case decl(UnsafeMutableRawPointer)
case stmt(UnsafeMutableRawPointer)
case expr(UnsafeMutableRawPointer)
case type(UnsafeMutableRawPointer)

var rawValue: UnsafeMutableRawPointer {
switch self {
case .decl(let ptr):
return ptr
case .stmt(let ptr):
return ptr
case .expr(let ptr):
return ptr
case .type(let ptr):
return ptr
}
}

func bridged() -> ASTNodeBridged {
switch self {
case .expr(let e):
return ASTNodeBridged(ptr: e, isExpr: true)
case .stmt(let s):
return ASTNodeBridged(ptr: s, isExpr: false)
default:
fatalError("Must be expr or stmt.")
}
}
}

/// Little utility wrapper that lets us have some mutable state within
/// immutable structs, and is therefore pretty evil.
@propertyWrapper
Expand All @@ -29,6 +60,8 @@ class Boxed<Value> {
}

struct ASTGenVisitor: SyntaxTransformVisitor {
typealias ResultType = ASTNode

let ctx: UnsafeMutableRawPointer
let base: UnsafePointer<UInt8>

Expand All @@ -41,11 +74,11 @@ struct ASTGenVisitor: SyntaxTransformVisitor {
// }

@_disfavoredOverload
public func visit(_ node: SourceFileSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: SourceFileSyntax) -> ASTNode {
fatalError("Use other overload.")
}

public func visitAny(_ node: Syntax) -> UnsafeMutableRawPointer {
public func visitAny(_ node: Syntax) -> ASTNode {
fatalError("Not implemented.")
}

Expand All @@ -55,13 +88,15 @@ struct ASTGenVisitor: SyntaxTransformVisitor {

for element in node.statements {
let swiftASTNodes = visit(element)
if element.item.is(StmtSyntax.self) {
out.append(SwiftTopLevelCodeDecl_createStmt(ctx, declContext, loc, swiftASTNodes, loc))
} else if element.item.is(ExprSyntax.self) {
out.append(SwiftTopLevelCodeDecl_createExpr(ctx, declContext, loc, swiftASTNodes, loc))
} else {
assert(element.item.is(DeclSyntax.self))
out.append(swiftASTNodes)
switch swiftASTNodes {
case .decl(let d):
out.append(d)
case .stmt(let s):
out.append(SwiftTopLevelCodeDecl_createStmt(ctx, declContext, loc, s, loc))
case .expr(let e):
out.append(SwiftTopLevelCodeDecl_createExpr(ctx, declContext, loc, e, loc))
case .type(_):
fatalError("Type should not exist at top level.")
}
}

Expand Down
53 changes: 23 additions & 30 deletions lib/ASTGen/Sources/ASTGen/Decls.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,35 +4,37 @@ import SwiftSyntax
import CASTBridging

extension ASTGenVisitor {
public func visit(_ node: TypealiasDeclSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: TypealiasDeclSyntax) -> ASTNode {
let aliasLoc = self.base.advanced(by: node.typealiasKeyword.position.utf8Offset).raw
let equalLoc = self.base.advanced(by: node.initializer.equal.position.utf8Offset).raw
var nameText = node.identifier.text
let name = nameText.withUTF8 { buf in
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
}
let nameLoc = self.base.advanced(by: node.identifier.position.utf8Offset).raw
let genericParams = node.genericParameterClause.map(self.visit)
let genericParams = node.genericParameterClause.map(self.visit).map { $0.rawValue }
let out = TypeAliasDecl_create(self.ctx, self.declContext, aliasLoc, equalLoc, name, nameLoc, genericParams)

let oldDeclContext = declContext
declContext = out.declContext
defer { declContext = oldDeclContext }

let underlying = self.visit(node.initializer.value)
let underlying = self.visit(node.initializer.value).rawValue
TypeAliasDecl_setUnderlyingTypeRepr(out.nominalDecl, underlying)

return out.decl
return .decl(out.decl)
}

public func visit(_ node: StructDeclSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: StructDeclSyntax) -> ASTNode {
let loc = self.base.advanced(by: node.position.utf8Offset).raw
var nameText = node.identifier.text
let name = nameText.withUTF8 { buf in
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
}

let genericParams = node.genericParameterClause.map(self.visit)
let genericParams = node.genericParameterClause
.map(self.visit)
.map { $0.rawValue }
let out = StructDecl_create(ctx, loc, name, loc, genericParams, declContext)
let oldDeclContext = declContext
declContext = out.declContext
Expand All @@ -42,10 +44,10 @@ extension ASTGenVisitor {
NominalTypeDecl_setMembers(out.nominalDecl, ref)
}

return out.decl
return .decl(out.decl)
}

public func visit(_ node: ClassDeclSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: ClassDeclSyntax) -> ASTNode {
let loc = self.base.advanced(by: node.position.utf8Offset).raw
var nameText = node.identifier.text
let name = nameText.withUTF8 { buf in
Expand All @@ -61,31 +63,22 @@ extension ASTGenVisitor {
NominalTypeDecl_setMembers(out.nominalDecl, ref)
}

return out.decl
return .decl(out.decl)
}

public func visit(_ node: VariableDeclSyntax) -> UnsafeMutableRawPointer {
let pattern = visit(node.bindings.first!.pattern)
let initializer = visit(node.bindings.first!.initializer!)
public func visit(_ node: VariableDeclSyntax) -> ASTNode {
let pattern = visit(node.bindings.first!.pattern).rawValue
let initializer = visit(node.bindings.first!.initializer!).rawValue

let loc = self.base.advanced(by: node.position.utf8Offset).raw
let isStateic = false // TODO: compute this
let isLet = node.letOrVarKeyword.tokenKind == .letKeyword

// TODO: don't drop "initializer" on the floor.
return SwiftVarDecl_create(ctx, nil, loc, isStateic, isLet, declContext)
return .decl(SwiftVarDecl_create(ctx, nil, loc, isStateic, isLet, declContext))
}

public func visit(_ node: CodeBlockSyntax) -> UnsafeMutableRawPointer {
let statements = node.statements.map(self.visit)
let loc = self.base.advanced(by: node.position.utf8Offset).raw

return statements.withBridgedArrayRef { ref in
BraceStmt_createStmt(ctx, loc, ref, loc)
}
}

public func visit(_ node: FunctionParameterSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: FunctionParameterSyntax) -> ASTNode {
let loc = self.base.advanced(by: node.position.utf8Offset).raw

let firstName: UnsafeMutableRawPointer?
Expand All @@ -109,34 +102,34 @@ extension ASTGenVisitor {
secondName = nil
}

return ParamDecl_create(ctx, loc, loc, firstName, loc, secondName, declContext)
return .decl(ParamDecl_create(ctx, loc, loc, firstName, loc, secondName, declContext))
}

public func visit(_ node: FunctionDeclSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: FunctionDeclSyntax) -> ASTNode {
let loc = self.base.advanced(by: node.position.utf8Offset).raw

var nameText = node.identifier.text
let name = nameText.withUTF8 { buf in
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
}

let body: UnsafeMutableRawPointer?
let body: ASTNode?
if let nodeBody = node.body {
body = visit(nodeBody)
} else {
body = nil
}

let returnType: UnsafeMutableRawPointer?
let returnType: ASTNode?
if let output = node.signature.output {
returnType = visit(output.returnType)
} else {
returnType = nil
}

let params = node.signature.input.parameterList.map { visit($0) }
return params.withBridgedArrayRef { ref in
FuncDecl_create(ctx, loc, false, loc, name, loc, false, nil, false, nil, loc, ref, loc, body, returnType, declContext)
}
return .decl(params.withBridgedArrayRef { ref in
FuncDecl_create(ctx, loc, false, loc, name, loc, false, nil, false, nil, loc, ref, loc, body?.rawValue, returnType?.rawValue, declContext)
})
}
}
38 changes: 19 additions & 19 deletions lib/ASTGen/Sources/ASTGen/Exprs.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,74 +4,74 @@ import SwiftSyntax
import CASTBridging

extension ASTGenVisitor {
public func visit(_ node: ClosureExprSyntax) -> UnsafeMutableRawPointer {
let statements = node.statements.map(self.visit)
public func visit(_ node: ClosureExprSyntax) -> ASTNode {
let statements = node.statements.map(self.visit).map { $0.bridged() }
let loc = self.base.advanced(by: node.position.utf8Offset).raw

let body = statements.withBridgedArrayRef { ref in
BraceStmt_createExpr(ctx, loc, ref, loc)
BraceStmt_create(ctx, loc, ref, loc)
}

return ClosureExpr_create(ctx, body, declContext)
return .expr(ClosureExpr_create(ctx, body, declContext))
}

public func visit(_ node: FunctionCallExprSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: FunctionCallExprSyntax) -> ASTNode {
// Transform the trailing closure into an argument.
if let trailingClosure = node.trailingClosure {
let tupleElement = TupleExprElementSyntax(label: nil, colon: nil, expression: ExprSyntax(trailingClosure), trailingComma: nil)

return visit(node.addArgument(tupleElement).withTrailingClosure(nil))
}

let args = visit(node.argumentList)
let args = visit(node.argumentList).rawValue
// TODO: hack
let callee = visit(node.calledExpression)
let callee = visit(node.calledExpression).rawValue

return SwiftFunctionCallExpr_create(self.ctx, callee, args)
return .expr(SwiftFunctionCallExpr_create(self.ctx, callee, args))
}

public func visit(_ node: IdentifierExprSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: IdentifierExprSyntax) -> ASTNode {
let loc = self.base.advanced(by: node.position.utf8Offset).raw

var text = node.identifier.text
let id = text.withUTF8 { buf in
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
}

return SwiftIdentifierExpr_create(ctx, id, loc)
return .expr(SwiftIdentifierExpr_create(ctx, id, loc))
}

public func visit(_ node: IdentifierPatternSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: IdentifierPatternSyntax) -> ASTNode {
let loc = self.base.advanced(by: node.position.utf8Offset).raw

var text = node.identifier.text
let id = text.withUTF8 { buf in
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
}

return SwiftIdentifierExpr_create(ctx, id, loc)
return .expr(SwiftIdentifierExpr_create(ctx, id, loc))
}

public func visit(_ node: MemberAccessExprSyntax) -> UnsafeMutableRawPointer {
public func visit(_ node: MemberAccessExprSyntax) -> ASTNode {
let loc = self.base.advanced(by: node.position.utf8Offset).raw
let base = visit(node.base!)
let base = visit(node.base!).rawValue
var nameText = node.name.text
let name = nameText.withUTF8 { buf in
return SwiftASTContext_getIdentifier(ctx, buf.baseAddress, buf.count)
}

return UnresolvedDotExpr_create(ctx, base, loc, name, loc)
return .expr(UnresolvedDotExpr_create(ctx, base, loc, name, loc))
}

public func visit(_ node: TupleExprElementListSyntax) -> UnsafeMutableRawPointer {
let elements = node.map(self.visit)
public func visit(_ node: TupleExprElementListSyntax) -> ASTNode {
let elements = node.map(self.visit).map { $0.rawValue }

// TODO: find correct paren locs.
let lParenLoc = self.base.advanced(by: node.position.utf8Offset).raw
let rParenLoc = self.base.advanced(by: node.position.utf8Offset).raw

return elements.withBridgedArrayRef { elementsRef in
return .expr(elements.withBridgedArrayRef { elementsRef in
SwiftTupleExpr_create(self.ctx, lParenLoc, elementsRef, rParenLoc)
}
})
}
}
Loading