@@ -7,7 +7,7 @@ import SwiftSyntaxMacros
7
7
// avoids depending on SwiftifyImport.swift
8
8
// all instances are reparsed and reinstantiated by the macro anyways,
9
9
// so linking is irrelevant
10
- enum SwiftifyExpr {
10
+ enum SwiftifyExpr : Hashable {
11
11
case param( _ index: Int )
12
12
case `return`
13
13
}
@@ -21,11 +21,21 @@ extension SwiftifyExpr: CustomStringConvertible {
21
21
}
22
22
}
23
23
24
+ enum DependenceType {
25
+ case borrow, copy
26
+ }
27
+
28
+ struct LifetimeDependence {
29
+ let dependsOn : Int
30
+ let type : DependenceType
31
+ }
32
+
24
33
protocol ParamInfo : CustomStringConvertible {
25
34
var description : String { get }
26
35
var original : SyntaxProtocol { get }
27
36
var pointerIndex : SwiftifyExpr { get }
28
37
var nonescaping : Bool { get set }
38
+ var dependencies : [ LifetimeDependence ] { get set }
29
39
30
40
func getBoundsCheckedThunkBuilder(
31
41
_ base: BoundsCheckedThunkBuilder , _ funcDecl: FunctionDeclSyntax ,
@@ -55,8 +65,9 @@ func getSwiftifyExprType(_ funcDecl: FunctionDeclSyntax, _ expr: SwiftifyExpr) -
55
65
struct CxxSpan : ParamInfo {
56
66
var pointerIndex : SwiftifyExpr
57
67
var nonescaping : Bool
58
- var original : SyntaxProtocol
68
+ var dependencies : [ LifetimeDependence ]
59
69
var typeMappings : [ String : String ]
70
+ var original : SyntaxProtocol
60
71
61
72
var description : String {
62
73
return " std::span(pointer: \( pointerIndex) , nonescaping: \( nonescaping) ) "
@@ -71,9 +82,8 @@ struct CxxSpan: ParamInfo {
71
82
return CxxSpanThunkBuilder ( base: base, index: i - 1 , signature: funcDecl. signature,
72
83
typeMappings: typeMappings, node: original, nonescaping: nonescaping)
73
84
case . return:
74
- // TODO: actually implement std::span in return position
75
- return CxxSpanThunkBuilder ( base: base, index: - 1 , signature: funcDecl. signature,
76
- typeMappings: typeMappings, node: original, nonescaping: nonescaping)
85
+ return CxxSpanReturnThunkBuilder ( base: base, signature: funcDecl. signature,
86
+ typeMappings: typeMappings, node: original)
77
87
}
78
88
}
79
89
}
@@ -83,6 +93,7 @@ struct CountedBy: ParamInfo {
83
93
var count : ExprSyntax
84
94
var sizedBy : Bool
85
95
var nonescaping : Bool
96
+ var dependencies : [ LifetimeDependence ]
86
97
var original : SyntaxProtocol
87
98
88
99
var description : String {
@@ -156,6 +167,8 @@ func getTypeName(_ type: TypeSyntax) throws -> TokenSyntax {
156
167
return memberType. name
157
168
case . identifierType:
158
169
return type. as ( IdentifierTypeSyntax . self) !. name
170
+ case . attributedType:
171
+ return try getTypeName ( type. as ( AttributedTypeSyntax . self) !. baseType)
159
172
default :
160
173
throw DiagnosticError ( " expected pointer type, got \( type) with kind \( type. kind) " , node: type)
161
174
}
@@ -169,6 +182,13 @@ func replaceTypeName(_ type: TypeSyntax, _ name: TokenSyntax) -> TypeSyntax {
169
182
return TypeSyntax ( idType. with ( \. name, name) )
170
183
}
171
184
185
+ func replaceBaseType( _ type: TypeSyntax , _ base: TypeSyntax ) -> TypeSyntax {
186
+ if let attributedType = type. as ( AttributedTypeSyntax . self) {
187
+ return TypeSyntax ( attributedType. with ( \. baseType, base) )
188
+ }
189
+ return base
190
+ }
191
+
172
192
func getPointerMutability( text: String ) -> Mutability ? {
173
193
switch text {
174
194
case " UnsafePointer " : return . Immutable
@@ -352,7 +372,7 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
352
372
let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
353
373
let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
354
374
. genericArgumentClause!. arguments. first!. argument) !
355
- types [ index] = TypeSyntax ( " Span< \( raw: try getTypeName ( genericArg) . text) > " )
375
+ types [ index] = replaceBaseType ( param . type , TypeSyntax ( " Span< \( raw: try getTypeName ( genericArg) . text) > " ) )
356
376
return try base. buildFunctionSignature ( types, returnType)
357
377
}
358
378
@@ -365,6 +385,38 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
365
385
}
366
386
}
367
387
388
+ struct CxxSpanReturnThunkBuilder : BoundsCheckedThunkBuilder {
389
+ public let base : BoundsCheckedThunkBuilder
390
+ public let signature : FunctionSignatureSyntax
391
+ public let typeMappings : [ String : String ]
392
+ public let node : SyntaxProtocol
393
+
394
+ func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
395
+ return [ ]
396
+ }
397
+
398
+ func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
399
+ -> FunctionSignatureSyntax {
400
+ assert ( returnType == nil )
401
+ let typeName = try getTypeName ( signature. returnClause!. type) . text
402
+ guard let desugaredType = typeMappings [ typeName] else {
403
+ throw DiagnosticError (
404
+ " unable to desugar type with name ' \( typeName) ' " , node: node)
405
+ }
406
+ let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
407
+ let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
408
+ . genericArgumentClause!. arguments. first!. argument) !
409
+ let newType = replaceBaseType ( signature. returnClause!. type,
410
+ TypeSyntax ( " Span< \( raw: try getTypeName ( genericArg) . text) > " ) )
411
+ return try base. buildFunctionSignature ( argTypes, newType)
412
+ }
413
+
414
+ func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
415
+ let call = try base. buildFunctionCall ( pointerArgs)
416
+ return " Span(_unsafeCxxSpan: \( call) ) "
417
+ }
418
+ }
419
+
368
420
protocol PointerBoundsThunkBuilder : BoundsCheckedThunkBuilder {
369
421
var oldType : TypeSyntax { get }
370
422
var newType : TypeSyntax { get throws }
@@ -723,7 +775,7 @@ public struct SwiftifyImportMacro: PeerMacro {
723
775
}
724
776
return CountedBy (
725
777
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: false ,
726
- nonescaping: false , original: ExprSyntax ( enumConstructorExpr) )
778
+ nonescaping: false , dependencies : [ ] , original: ExprSyntax ( enumConstructorExpr) )
727
779
}
728
780
729
781
static func parseSizedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ParamInfo {
@@ -738,7 +790,7 @@ public struct SwiftifyImportMacro: PeerMacro {
738
790
let unwrappedCountExpr = ExprSyntax ( stringLiteral: sizeExprStringLit. representedLiteralValue!)
739
791
return CountedBy (
740
792
pointerIndex: pointerExpr, count: unwrappedCountExpr, sizedBy: true , nonescaping: false ,
741
- original: ExprSyntax ( enumConstructorExpr) )
793
+ dependencies : [ ] , original: ExprSyntax ( enumConstructorExpr) )
742
794
}
743
795
744
796
static func parseEndedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ParamInfo {
@@ -758,6 +810,24 @@ public struct SwiftifyImportMacro: PeerMacro {
758
810
return pointerParamIndex
759
811
}
760
812
813
+ static func parseLifetimeDependence( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ( SwiftifyExpr , LifetimeDependence ) {
814
+ let argumentList = enumConstructorExpr. arguments
815
+ let pointer : SwiftifyExpr = try parseSwiftifyExpr ( try getArgumentByName ( argumentList, " pointer " ) )
816
+ let dependsOn : Int = try getIntLiteralValue ( try getArgumentByName ( argumentList, " dependsOn " ) )
817
+ let type = try getArgumentByName ( argumentList, " type " )
818
+ let depType : DependenceType
819
+ switch try parseEnumName ( type) {
820
+ case " borrow " :
821
+ depType = DependenceType . borrow
822
+ case " copy " :
823
+ depType = DependenceType . copy
824
+ default :
825
+ throw DiagnosticError ( " expected '.copy' or '.borrow', got ' \( type) ' " , node: type)
826
+ }
827
+ let dependence = LifetimeDependence ( dependsOn: dependsOn, type: depType)
828
+ return ( pointer, dependence)
829
+ }
830
+
761
831
static func parseTypeMappingParam( _ paramAST: LabeledExprSyntax ? ) throws -> [ String : String ] ? {
762
832
guard let unwrappedParamAST = paramAST else {
763
833
return nil
@@ -786,31 +856,38 @@ public struct SwiftifyImportMacro: PeerMacro {
786
856
return dict
787
857
}
788
858
789
- static func parseCxxSpanParams (
859
+ static func parseCxxSpansInSignature (
790
860
_ signature: FunctionSignatureSyntax ,
791
861
_ typeMappings: [ String : String ] ?
792
862
) throws -> [ ParamInfo ] {
793
863
guard let typeMappings else {
794
864
return [ ]
795
865
}
796
866
var result : [ ParamInfo ] = [ ]
797
- for (idx , param ) in signature . parameterClause . parameters . enumerated ( ) {
798
- let typeName = try getTypeName ( param . type) . text;
867
+ let process = { type , expr , orig in
868
+ let typeName = try getTypeName ( type) . text;
799
869
if let desugaredType = typeMappings [ typeName] {
800
870
if let unqualifiedDesugaredType = getUnqualifiedStdName ( desugaredType) {
801
871
if unqualifiedDesugaredType. starts ( with: " span< " ) {
802
- result. append ( CxxSpan ( pointerIndex: . param ( idx + 1 ) , nonescaping: false ,
803
- original : param , typeMappings: typeMappings) )
872
+ result. append ( CxxSpan ( pointerIndex: expr , nonescaping: false ,
873
+ dependencies : [ ] , typeMappings: typeMappings, original : orig ) )
804
874
}
805
875
}
806
876
}
807
877
}
878
+ for (idx, param) in signature. parameterClause. parameters. enumerated ( ) {
879
+ try process ( param. type, . param( idx + 1 ) , param)
880
+ }
881
+ if let retClause = signature. returnClause {
882
+ try process ( retClause. type, . `return`, retClause)
883
+ }
808
884
return result
809
885
}
810
886
811
887
static func parseMacroParam(
812
888
_ paramAST: LabeledExprSyntax , _ signature: FunctionSignatureSyntax ,
813
- nonescapingPointers: inout Set < Int >
889
+ nonescapingPointers: inout Set < Int > ,
890
+ lifetimeDependencies: inout [ SwiftifyExpr : [ LifetimeDependence ] ]
814
891
) throws -> ParamInfo ? {
815
892
let paramExpr = paramAST. expression
816
893
guard let enumConstructorExpr = paramExpr. as ( FunctionCallExprSyntax . self) else {
@@ -826,9 +903,23 @@ public struct SwiftifyImportMacro: PeerMacro {
826
903
let index = try parseNonEscaping ( enumConstructorExpr)
827
904
nonescapingPointers. insert ( index)
828
905
return nil
906
+ case " lifetimeDependence " :
907
+ let ( expr, dependence) = try parseLifetimeDependence ( enumConstructorExpr)
908
+ lifetimeDependencies [ expr, default: [ ] ] . append ( dependence)
909
+ // We assume pointers annotated with lifetimebound do not escape.
910
+ if dependence. type == DependenceType . copy {
911
+ nonescapingPointers. insert ( dependence. dependsOn)
912
+ }
913
+ // The escaping is controlled when a parameter is the target of a lifetimebound.
914
+ // So we want to do the transformation to Swift's Span.
915
+ let idx = paramOrReturnIndex ( expr)
916
+ if idx != - 1 {
917
+ nonescapingPointers. insert ( idx)
918
+ }
919
+ return nil
829
920
default :
830
921
throw DiagnosticError (
831
- " expected 'countedBy', 'sizedBy', 'endedBy' or 'nonescaping ', got ' \( enumName) ' " ,
922
+ " expected 'countedBy', 'sizedBy', 'endedBy', 'nonescaping' or 'lifetimeDependence ', got ' \( enumName) ' " ,
832
923
node: enumConstructorExpr)
833
924
}
834
925
}
@@ -898,11 +989,48 @@ public struct SwiftifyImportMacro: PeerMacro {
898
989
}
899
990
900
991
static func setNonescapingPointers( _ args: inout [ ParamInfo ] , _ nonescapingPointers: Set < Int > ) {
992
+ if args. isEmpty {
993
+ return
994
+ }
901
995
for i in 0 ... args. count - 1 where nonescapingPointers. contains ( paramOrReturnIndex ( args [ i] . pointerIndex) ) {
902
996
args [ i] . nonescaping = true
903
997
}
904
998
}
905
999
1000
+ static func setLifetimeDependencies( _ args: inout [ ParamInfo ] , _ lifetimeDependencies: [ SwiftifyExpr : [ LifetimeDependence ] ] ) {
1001
+ if args. isEmpty {
1002
+ return
1003
+ }
1004
+ for i in 0 ... args. count - 1 where lifetimeDependencies. keys. contains ( args [ i] . pointerIndex) {
1005
+ args [ i] . dependencies = lifetimeDependencies [ args [ i] . pointerIndex] !
1006
+ }
1007
+ }
1008
+
1009
+ static func lifetimeAttributes( _ funcDecl: FunctionDeclSyntax ,
1010
+ _ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ] ) -> [ AttributeListSyntax . Element ] {
1011
+ let returnDependencies = dependencies [ . `return`, default: [ ] ]
1012
+ if returnDependencies. isEmpty {
1013
+ return [ ]
1014
+ }
1015
+ var args : [ LabeledExprSyntax ] = [ ]
1016
+ for dependence in returnDependencies {
1017
+ if ( dependence. type == . borrow) {
1018
+ args. append ( LabeledExprSyntax ( expression:
1019
+ DeclReferenceExprSyntax ( baseName: TokenSyntax ( " borrow " ) ) ) )
1020
+ }
1021
+ args. append ( LabeledExprSyntax ( expression:
1022
+ DeclReferenceExprSyntax ( baseName: TokenSyntax ( tryGetParamName ( funcDecl, . param( dependence. dependsOn) ) ) !) ,
1023
+ trailingComma: . commaToken( ) ) )
1024
+ }
1025
+ args [ args. count - 1 ] = args [ args. count - 1 ] . with ( \. trailingComma, nil )
1026
+ return [ . attribute( AttributeSyntax (
1027
+ atSign: . atSignToken( ) ,
1028
+ attributeName: IdentifierTypeSyntax ( name: " lifetime " ) ,
1029
+ leftParen: . leftParenToken( ) ,
1030
+ arguments: . argumentList( LabeledExprListSyntax ( args) ) ,
1031
+ rightParen: . rightParenToken( ) ) ) ]
1032
+ }
1033
+
906
1034
public static func expansion(
907
1035
of node: AttributeSyntax ,
908
1036
providingPeersOf declaration: some DeclSyntaxProtocol ,
@@ -920,13 +1048,21 @@ public struct SwiftifyImportMacro: PeerMacro {
920
1048
arguments = arguments. dropLast ( )
921
1049
}
922
1050
var nonescapingPointers = Set < Int > ( )
1051
+ var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
923
1052
var parsedArgs = try arguments. compactMap {
924
- try parseMacroParam ( $0, funcDecl. signature, nonescapingPointers: & nonescapingPointers)
1053
+ try parseMacroParam ( $0, funcDecl. signature, nonescapingPointers: & nonescapingPointers,
1054
+ lifetimeDependencies: & lifetimeDependencies)
925
1055
}
926
- parsedArgs. append ( contentsOf: try parseCxxSpanParams ( funcDecl. signature, typeMappings) )
1056
+ parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl. signature, typeMappings) )
927
1057
setNonescapingPointers ( & parsedArgs, nonescapingPointers)
1058
+ setLifetimeDependencies ( & parsedArgs, lifetimeDependencies)
1059
+ // We only transform non-escaping spans.
928
1060
parsedArgs = parsedArgs. filter {
929
- !( $0 is CxxSpan ) || ( $0 as! CxxSpan ) . nonescaping
1061
+ if let cxxSpanArg = $0 as? CxxSpan {
1062
+ return cxxSpanArg. nonescaping || cxxSpanArg. pointerIndex == . return
1063
+ } else {
1064
+ return true
1065
+ }
930
1066
}
931
1067
try checkArgs ( parsedArgs, funcDecl)
932
1068
let baseBuilder = FunctionCallBuilder ( funcDecl)
@@ -951,6 +1087,7 @@ public struct SwiftifyImportMacro: PeerMacro {
951
1087
returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
952
1088
expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
953
1089
let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1090
+ let lifetimeAttrs = lifetimeAttributes ( funcDecl, lifetimeDependencies)
954
1091
let newFunc =
955
1092
funcDecl
956
1093
. with ( \. signature, newSignature)
@@ -970,7 +1107,8 @@ public struct SwiftifyImportMacro: PeerMacro {
970
1107
AttributeSyntax (
971
1108
atSign: . atSignToken( ) ,
972
1109
attributeName: IdentifierTypeSyntax ( name: " _alwaysEmitIntoClient " ) ) )
973
- ] )
1110
+ ]
1111
+ + lifetimeAttrs)
974
1112
return [ DeclSyntax ( newFunc) ]
975
1113
} catch let error as DiagnosticError {
976
1114
context. diagnose (
0 commit comments