Skip to content

Commit 8c03a31

Browse files
author
Gabor Horvath
committed
[cxx-interop] Support transforming lifetimebound spans
This PR adds basic support for storing lifetime dependence information, transform Span return types, and generate lifetime annotations. rdar://139074571
1 parent 6b2fb2e commit 8c03a31

File tree

8 files changed

+256
-19
lines changed

8 files changed

+256
-19
lines changed

lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift

Lines changed: 157 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import SwiftSyntaxMacros
77
// avoids depending on SwiftifyImport.swift
88
// all instances are reparsed and reinstantiated by the macro anyways,
99
// so linking is irrelevant
10-
enum SwiftifyExpr {
10+
enum SwiftifyExpr: Hashable {
1111
case param(_ index: Int)
1212
case `return`
1313
}
@@ -21,11 +21,21 @@ extension SwiftifyExpr: CustomStringConvertible {
2121
}
2222
}
2323

24+
enum DependenceType {
25+
case borrow, copy
26+
}
27+
28+
struct LifetimeDependence {
29+
let dependsOn: Int
30+
let type: DependenceType
31+
}
32+
2433
protocol ParamInfo: CustomStringConvertible {
2534
var description: String { get }
2635
var original: SyntaxProtocol { get }
2736
var pointerIndex: SwiftifyExpr { get }
2837
var nonescaping: Bool { get set }
38+
var dependencies: [LifetimeDependence] { get set }
2939

3040
func getBoundsCheckedThunkBuilder(
3141
_ base: BoundsCheckedThunkBuilder, _ funcDecl: FunctionDeclSyntax,
@@ -55,8 +65,9 @@ func getSwiftifyExprType(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -
5565
struct CxxSpan: ParamInfo {
5666
var pointerIndex: SwiftifyExpr
5767
var nonescaping: Bool
58-
var original: SyntaxProtocol
68+
var dependencies: [LifetimeDependence]
5969
var typeMappings: [String: String]
70+
var original: SyntaxProtocol
6071

6172
var description: String {
6273
return "std::span(pointer: \(pointerIndex), nonescaping: \(nonescaping))"
@@ -71,9 +82,8 @@ struct CxxSpan: ParamInfo {
7182
return CxxSpanThunkBuilder(base: base, index: i - 1, signature: funcDecl.signature,
7283
typeMappings: typeMappings, node: original, nonescaping: nonescaping)
7384
case .return:
74-
// TODO: actually implement std::span in return position
75-
return CxxSpanThunkBuilder(base: base, index: -1, signature: funcDecl.signature,
76-
typeMappings: typeMappings, node: original, nonescaping: nonescaping)
85+
return CxxSpanReturnThunkBuilder(base: base, signature: funcDecl.signature,
86+
typeMappings: typeMappings, node: original)
7787
}
7888
}
7989
}
@@ -83,6 +93,7 @@ struct CountedBy: ParamInfo {
8393
var count: ExprSyntax
8494
var sizedBy: Bool
8595
var nonescaping: Bool
96+
var dependencies: [LifetimeDependence]
8697
var original: SyntaxProtocol
8798

8899
var description: String {
@@ -156,6 +167,8 @@ func getTypeName(_ type: TypeSyntax) throws -> TokenSyntax {
156167
return memberType.name
157168
case .identifierType:
158169
return type.as(IdentifierTypeSyntax.self)!.name
170+
case .attributedType:
171+
return try getTypeName(type.as(AttributedTypeSyntax.self)!.baseType)
159172
default:
160173
throw DiagnosticError("expected pointer type, got \(type) with kind \(type.kind)", node: type)
161174
}
@@ -169,6 +182,13 @@ func replaceTypeName(_ type: TypeSyntax, _ name: TokenSyntax) -> TypeSyntax {
169182
return TypeSyntax(idType.with(\.name, name))
170183
}
171184

185+
func replaceBaseType(_ type: TypeSyntax, _ base: TypeSyntax) -> TypeSyntax {
186+
if let attributedType = type.as(AttributedTypeSyntax.self) {
187+
return TypeSyntax(attributedType.with(\.baseType, base))
188+
}
189+
return base
190+
}
191+
172192
func getPointerMutability(text: String) -> Mutability? {
173193
switch text {
174194
case "UnsafePointer": return .Immutable
@@ -352,7 +372,7 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
352372
let parsedDesugaredType = TypeSyntax("\(raw: getUnqualifiedStdName(desugaredType)!)")
353373
let genericArg = TypeSyntax(parsedDesugaredType.as(IdentifierTypeSyntax.self)!
354374
.genericArgumentClause!.arguments.first!.argument)!
355-
types[index] = TypeSyntax("Span<\(raw: try getTypeName(genericArg).text)>")
375+
types[index] = replaceBaseType(param.type, TypeSyntax("Span<\(raw: try getTypeName(genericArg).text)>"))
356376
return try base.buildFunctionSignature(types, returnType)
357377
}
358378

@@ -365,6 +385,38 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
365385
}
366386
}
367387

388+
struct CxxSpanReturnThunkBuilder: BoundsCheckedThunkBuilder {
389+
public let base: BoundsCheckedThunkBuilder
390+
public let signature: FunctionSignatureSyntax
391+
public let typeMappings: [String: String]
392+
public let node: SyntaxProtocol
393+
394+
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
395+
return []
396+
}
397+
398+
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
399+
-> FunctionSignatureSyntax {
400+
assert(returnType == nil)
401+
let typeName = try getTypeName(signature.returnClause!.type).text
402+
guard let desugaredType = typeMappings[typeName] else {
403+
throw DiagnosticError(
404+
"unable to desugar type with name '\(typeName)'", node: node)
405+
}
406+
let parsedDesugaredType = TypeSyntax("\(raw: getUnqualifiedStdName(desugaredType)!)")
407+
let genericArg = TypeSyntax(parsedDesugaredType.as(IdentifierTypeSyntax.self)!
408+
.genericArgumentClause!.arguments.first!.argument)!
409+
let newType = replaceBaseType(signature.returnClause!.type,
410+
TypeSyntax("Span<\(raw: try getTypeName(genericArg).text)>"))
411+
return try base.buildFunctionSignature(argTypes, newType)
412+
}
413+
414+
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
415+
let call = try base.buildFunctionCall(pointerArgs)
416+
return "Span(_unsafeCxxSpan: \(call))"
417+
}
418+
}
419+
368420
protocol PointerBoundsThunkBuilder: BoundsCheckedThunkBuilder {
369421
var oldType: TypeSyntax { get }
370422
var newType: TypeSyntax { get throws }
@@ -723,7 +775,7 @@ public struct SwiftifyImportMacro: PeerMacro {
723775
}
724776
return CountedBy(
725777
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: false,
726-
nonescaping: false, original: ExprSyntax(enumConstructorExpr))
778+
nonescaping: false, dependencies: [], original: ExprSyntax(enumConstructorExpr))
727779
}
728780

729781
static func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo {
@@ -738,7 +790,7 @@ public struct SwiftifyImportMacro: PeerMacro {
738790
let unwrappedCountExpr = ExprSyntax(stringLiteral: sizeExprStringLit.representedLiteralValue!)
739791
return CountedBy(
740792
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: true, nonescaping: false,
741-
original: ExprSyntax(enumConstructorExpr))
793+
dependencies: [], original: ExprSyntax(enumConstructorExpr))
742794
}
743795

744796
static func parseEndedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> ParamInfo {
@@ -758,6 +810,24 @@ public struct SwiftifyImportMacro: PeerMacro {
758810
return pointerParamIndex
759811
}
760812

813+
static func parseLifetimeDependence(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> (SwiftifyExpr, LifetimeDependence) {
814+
let argumentList = enumConstructorExpr.arguments
815+
let pointer: SwiftifyExpr = try parseSwiftifyExpr(try getArgumentByName(argumentList, "pointer"))
816+
let dependsOn: Int = try getIntLiteralValue(try getArgumentByName(argumentList, "dependsOn"))
817+
let type = try getArgumentByName(argumentList, "type")
818+
let depType: DependenceType
819+
switch try parseEnumName(type) {
820+
case "borrow":
821+
depType = DependenceType.borrow
822+
case "copy":
823+
depType = DependenceType.copy
824+
default:
825+
throw DiagnosticError("expected '.copy' or '.borrow', got '\(type)'", node: type)
826+
}
827+
let dependence = LifetimeDependence(dependsOn: dependsOn, type: depType)
828+
return (pointer, dependence)
829+
}
830+
761831
static func parseTypeMappingParam(_ paramAST: LabeledExprSyntax?) throws -> [String: String]? {
762832
guard let unwrappedParamAST = paramAST else {
763833
return nil
@@ -786,31 +856,38 @@ public struct SwiftifyImportMacro: PeerMacro {
786856
return dict
787857
}
788858

789-
static func parseCxxSpanParams(
859+
static func parseCxxSpansInSignature(
790860
_ signature: FunctionSignatureSyntax,
791861
_ typeMappings: [String: String]?
792862
) throws -> [ParamInfo] {
793863
guard let typeMappings else {
794864
return []
795865
}
796866
var result : [ParamInfo] = []
797-
for (idx, param) in signature.parameterClause.parameters.enumerated() {
798-
let typeName = try getTypeName(param.type).text;
867+
let process = { type, expr, orig in
868+
let typeName = try getTypeName(type).text;
799869
if let desugaredType = typeMappings[typeName] {
800870
if let unqualifiedDesugaredType = getUnqualifiedStdName(desugaredType) {
801871
if unqualifiedDesugaredType.starts(with: "span<") {
802-
result.append(CxxSpan(pointerIndex: .param(idx + 1), nonescaping: false,
803-
original: param, typeMappings: typeMappings))
872+
result.append(CxxSpan(pointerIndex: expr, nonescaping: false,
873+
dependencies: [], typeMappings: typeMappings, original: orig))
804874
}
805875
}
806876
}
807877
}
878+
for (idx, param) in signature.parameterClause.parameters.enumerated() {
879+
try process(param.type, .param(idx + 1), param)
880+
}
881+
if let retClause = signature.returnClause {
882+
try process(retClause.type, .`return`, retClause)
883+
}
808884
return result
809885
}
810886

811887
static func parseMacroParam(
812888
_ paramAST: LabeledExprSyntax, _ signature: FunctionSignatureSyntax,
813-
nonescapingPointers: inout Set<Int>
889+
nonescapingPointers: inout Set<Int>,
890+
lifetimeDependencies: inout [SwiftifyExpr: [LifetimeDependence]]
814891
) throws -> ParamInfo? {
815892
let paramExpr = paramAST.expression
816893
guard let enumConstructorExpr = paramExpr.as(FunctionCallExprSyntax.self) else {
@@ -826,9 +903,23 @@ public struct SwiftifyImportMacro: PeerMacro {
826903
let index = try parseNonEscaping(enumConstructorExpr)
827904
nonescapingPointers.insert(index)
828905
return nil
906+
case "lifetimeDependence":
907+
let (expr, dependence) = try parseLifetimeDependence(enumConstructorExpr)
908+
lifetimeDependencies[expr, default: []].append(dependence)
909+
// We assume pointers annotated with lifetimebound do not escape.
910+
if dependence.type == DependenceType.copy {
911+
nonescapingPointers.insert(dependence.dependsOn)
912+
}
913+
// The escaping is controlled when a parameter is the target of a lifetimebound.
914+
// So we want to do the transformation to Swift's Span.
915+
let idx = paramOrReturnIndex(expr)
916+
if idx != -1 {
917+
nonescapingPointers.insert(idx)
918+
}
919+
return nil
829920
default:
830921
throw DiagnosticError(
831-
"expected 'countedBy', 'sizedBy', 'endedBy' or 'nonescaping', got '\(enumName)'",
922+
"expected 'countedBy', 'sizedBy', 'endedBy', 'nonescaping' or 'lifetimeDependence', got '\(enumName)'",
832923
node: enumConstructorExpr)
833924
}
834925
}
@@ -898,11 +989,48 @@ public struct SwiftifyImportMacro: PeerMacro {
898989
}
899990

900991
static func setNonescapingPointers(_ args: inout [ParamInfo], _ nonescapingPointers: Set<Int>) {
992+
if args.isEmpty {
993+
return
994+
}
901995
for i in 0...args.count - 1 where nonescapingPointers.contains(paramOrReturnIndex(args[i].pointerIndex)) {
902996
args[i].nonescaping = true
903997
}
904998
}
905999

1000+
static func setLifetimeDependencies(_ args: inout [ParamInfo], _ lifetimeDependencies: [SwiftifyExpr: [LifetimeDependence]]) {
1001+
if args.isEmpty {
1002+
return
1003+
}
1004+
for i in 0...args.count - 1 where lifetimeDependencies.keys.contains(args[i].pointerIndex) {
1005+
args[i].dependencies = lifetimeDependencies[args[i].pointerIndex]!
1006+
}
1007+
}
1008+
1009+
static func lifetimeAttributes(_ funcDecl: FunctionDeclSyntax,
1010+
_ dependencies: [SwiftifyExpr: [LifetimeDependence]]) -> [AttributeListSyntax.Element] {
1011+
let returnDependencies = dependencies[.`return`, default: []]
1012+
if returnDependencies.isEmpty {
1013+
return []
1014+
}
1015+
var args : [LabeledExprSyntax] = []
1016+
for dependence in returnDependencies {
1017+
if (dependence.type == .borrow) {
1018+
args.append(LabeledExprSyntax(expression:
1019+
DeclReferenceExprSyntax(baseName: TokenSyntax("borrow"))))
1020+
}
1021+
args.append(LabeledExprSyntax(expression:
1022+
DeclReferenceExprSyntax(baseName: TokenSyntax(tryGetParamName(funcDecl, .param(dependence.dependsOn)))!),
1023+
trailingComma: .commaToken()))
1024+
}
1025+
args[args.count - 1] = args[args.count - 1].with(\.trailingComma, nil)
1026+
return [.attribute(AttributeSyntax(
1027+
atSign: .atSignToken(),
1028+
attributeName: IdentifierTypeSyntax(name: "lifetime"),
1029+
leftParen: .leftParenToken(),
1030+
arguments: .argumentList(LabeledExprListSyntax(args)),
1031+
rightParen: .rightParenToken()))]
1032+
}
1033+
9061034
public static func expansion(
9071035
of node: AttributeSyntax,
9081036
providingPeersOf declaration: some DeclSyntaxProtocol,
@@ -920,13 +1048,21 @@ public struct SwiftifyImportMacro: PeerMacro {
9201048
arguments = arguments.dropLast()
9211049
}
9221050
var nonescapingPointers = Set<Int>()
1051+
var lifetimeDependencies : [SwiftifyExpr: [LifetimeDependence]] = [:]
9231052
var parsedArgs = try arguments.compactMap {
924-
try parseMacroParam($0, funcDecl.signature, nonescapingPointers: &nonescapingPointers)
1053+
try parseMacroParam($0, funcDecl.signature, nonescapingPointers: &nonescapingPointers,
1054+
lifetimeDependencies: &lifetimeDependencies)
9251055
}
926-
parsedArgs.append(contentsOf: try parseCxxSpanParams(funcDecl.signature, typeMappings))
1056+
parsedArgs.append(contentsOf: try parseCxxSpansInSignature(funcDecl.signature, typeMappings))
9271057
setNonescapingPointers(&parsedArgs, nonescapingPointers)
1058+
setLifetimeDependencies(&parsedArgs, lifetimeDependencies)
1059+
// We only transform non-escaping spans.
9281060
parsedArgs = parsedArgs.filter {
929-
!($0 is CxxSpan) || ($0 as! CxxSpan).nonescaping
1061+
if let cxxSpanArg = $0 as? CxxSpan {
1062+
return cxxSpanArg.nonescaping || cxxSpanArg.pointerIndex == .return
1063+
} else {
1064+
return true
1065+
}
9301066
}
9311067
try checkArgs(parsedArgs, funcDecl)
9321068
let baseBuilder = FunctionCallBuilder(funcDecl)
@@ -951,6 +1087,7 @@ public struct SwiftifyImportMacro: PeerMacro {
9511087
returnKeyword: .keyword(.return, trailingTrivia: " "),
9521088
expression: try builder.buildFunctionCall([:]))))
9531089
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
1090+
let lifetimeAttrs = lifetimeAttributes(funcDecl, lifetimeDependencies)
9541091
let newFunc =
9551092
funcDecl
9561093
.with(\.signature, newSignature)
@@ -970,7 +1107,8 @@ public struct SwiftifyImportMacro: PeerMacro {
9701107
AttributeSyntax(
9711108
atSign: .atSignToken(),
9721109
attributeName: IdentifierTypeSyntax(name: "_alwaysEmitIntoClient")))
973-
])
1110+
]
1111+
+ lifetimeAttrs)
9741112
return [DeclSyntax(newFunc)]
9751113
} catch let error as DiagnosticError {
9761114
context.diagnose(

stdlib/public/core/SwiftifyImport.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@ public enum _SwiftifyExpr {
22
case param(_ index: Int)
33
case `return`
44
}
5+
6+
public enum _DependenceType {
7+
case borrow
8+
case copy
9+
}
510
/// Different ways to annotate pointer parameters using the `@_SwiftifyImport` macro.
611
/// All indices into parameter lists start at 1. Indices __must__ be integer literals, and strings
712
/// __must__ be string literals, because their contents are parsed by the `@_SwiftifyImport` macro.
@@ -34,6 +39,9 @@ public enum _SwiftifyInfo {
3439
/// object past the lifetime of the function.
3540
/// Parameter pointer: index of pointer in function parameter list.
3641
case nonescaping(pointer: _SwiftifyExpr)
42+
/// Can express lifetime dependencies between inputs and outputs of a function.
43+
/// 'dependsOn' is the input on which the output 'pointer' depends.
44+
case lifetimeDependence(dependsOn: Int, pointer: _SwiftifyExpr, type: _DependenceType)
3745
}
3846

3947
/// Generates a safe wrapper for function with Unsafe[Mutable][Raw]Pointer[?] or std::span arguments.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module StdSpan {
2+
header "std-span.h"
3+
requires cplusplus
4+
export *
5+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once
2+
3+
#include <span>
4+
5+
using SpanOfInt = std::span<const int>;

0 commit comments

Comments
 (0)