@@ -45,19 +45,11 @@ protocol ParamInfo: CustomStringConvertible {
45
45
) -> BoundsCheckedThunkBuilder
46
46
}
47
47
48
- func getParamName( _ param: FunctionParameterSyntax , _ paramIndex: Int ) -> TokenSyntax {
49
- let name = param. secondName ?? param. firstName
50
- if name. trimmed. text == " _ " {
51
- return " _param \( raw: paramIndex) "
52
- }
53
- return name
54
- }
55
-
56
48
func tryGetParamName( _ funcDecl: FunctionDeclSyntax , _ expr: SwiftifyExpr ) -> TokenSyntax ? {
57
49
switch expr {
58
50
case . param( let i) :
59
51
let funcParam = getParam ( funcDecl, i - 1 )
60
- return getParamName ( funcParam, i - 1 )
52
+ return funcParam. name
61
53
case . `self`:
62
54
return . keyword( . self )
63
55
default : return nil
@@ -427,12 +419,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
427
419
// filter out deleted parameters, i.e. ones where argTypes[i] _contains_ nil
428
420
return type == nil || type! != nil
429
421
} . map { ( i: Int , e: FunctionParameterSyntax ) in
430
- let param = e. with ( \. type, ( argTypes [ i] ?? e. type) !)
431
- let name = param. secondName ?? param. firstName
432
- if name. trimmed. text == " _ " {
433
- return param. with ( \. secondName, getParamName ( param, i) )
434
- }
435
- return param
422
+ e. with ( \. type, ( argTypes [ i] ?? e. type) !)
436
423
}
437
424
if let last = newParams. popLast ( ) {
438
425
newParams. append ( last. with ( \. trailingComma, nil ) )
@@ -450,9 +437,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
450
437
let functionRef = DeclReferenceExprSyntax ( baseName: base. name)
451
438
let args : [ ExprSyntax ] = base. signature. parameterClause. parameters. enumerated ( )
452
439
. map { ( i: Int , param: FunctionParameterSyntax ) in
453
- let name = getParamName ( param, i)
454
- let declref = DeclReferenceExprSyntax ( baseName: name)
455
- return pointerArgs [ i] ?? ExprSyntax ( declref)
440
+ return pointerArgs [ i] ?? ExprSyntax ( " \( param. name) " )
456
441
}
457
442
let labels : [ TokenSyntax ? ] = base. signature. parameterClause. parameters. map { param in
458
443
let firstName = param. firstName. trimmed
@@ -468,7 +453,8 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
468
453
comma = . commaToken( )
469
454
}
470
455
let colon : TokenSyntax ? = label != nil ? . colonToken( ) : nil
471
- return LabeledExprSyntax ( label: label, colon: colon, expression: arg, trailingComma: comma)
456
+ // The compiler emits warnings if you unnecessarily escape labels in function calls
457
+ return LabeledExprSyntax ( label: label? . withoutBackticks, colon: colon, expression: arg, trailingComma: comma)
472
458
}
473
459
let call = ExprSyntax (
474
460
FunctionCallExprSyntax (
@@ -510,7 +496,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
510
496
args [ index] = ExprSyntax ( " \( raw: typeName) ( \( raw: name) ) " )
511
497
return try base. buildFunctionCall ( args)
512
498
} else {
513
- let unwrappedName = TokenSyntax ( " _ \( name) Ptr " )
499
+ let unwrappedName = TokenSyntax ( " _ \( name. withoutBackticks ) Ptr " )
514
500
args [ index] = ExprSyntax ( " \( raw: typeName) ( \( unwrappedName) ) " )
515
501
let call = try base. buildFunctionCall ( args)
516
502
@@ -663,7 +649,7 @@ extension ParamBoundsThunkBuilder {
663
649
}
664
650
665
651
var name : TokenSyntax {
666
- getParamName ( param, index )
652
+ param. name
667
653
}
668
654
}
669
655
@@ -796,7 +782,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
796
782
}
797
783
798
784
func buildUnwrapCall( _ argOverrides: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
799
- let unwrappedName = TokenSyntax ( " _ \( name) Ptr " )
785
+ let unwrappedName = TokenSyntax ( " _ \( name. withoutBackticks ) Ptr " ) . escapeIfNeeded
800
786
var args = argOverrides
801
787
let argExpr = ExprSyntax ( " \( unwrappedName) .baseAddress " )
802
788
assert ( args [ index] == nil )
@@ -809,7 +795,7 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
809
795
}
810
796
}
811
797
let call = try base. buildFunctionCall ( args)
812
- let ptrRef = unwrapIfNullable ( ExprSyntax ( DeclReferenceExprSyntax ( baseName : name) ) )
798
+ let ptrRef = unwrapIfNullable ( " \( name) " )
813
799
814
800
let funcName =
815
801
switch ( isSizedBy, isMutablePointerType ( oldType) ) {
@@ -1004,7 +990,7 @@ func parseSwiftifyExpr(_ expr: ExprSyntax) throws -> SwiftifyExpr {
1004
990
}
1005
991
1006
992
func parseCountedByEnum(
1007
- _ enumConstructorExpr: FunctionCallExprSyntax , _ signature: FunctionSignatureSyntax
993
+ _ enumConstructorExpr: FunctionCallExprSyntax , _ signature: FunctionSignatureSyntax , _ rewriter : CountExprRewriter
1008
994
) throws -> ParamInfo {
1009
995
let argumentList = enumConstructorExpr. arguments
1010
996
let pointerExprArg = try getArgumentByName ( argumentList, " pointer " )
@@ -1015,7 +1001,8 @@ func parseCountedByEnum(
1015
1001
" expected string literal for 'count' parameter, got \( countExprArg) " , node: countExprArg)
1016
1002
}
1017
1003
let unwrappedCountExpr = ExprSyntax ( stringLiteral: countExprStringLit. representedLiteralValue!)
1018
- if let countVar = unwrappedCountExpr. as ( DeclReferenceExprSyntax . self) {
1004
+ let rewrittenCountExpr = rewriter. visit ( unwrappedCountExpr)
1005
+ if let countVar = rewrittenCountExpr. as ( DeclReferenceExprSyntax . self) {
1019
1006
// Perform this lookup here so we can override the position to point to the string literal
1020
1007
// instead of line 1, column 1
1021
1008
do {
@@ -1025,11 +1012,11 @@ func parseCountedByEnum(
1025
1012
}
1026
1013
}
1027
1014
return CountedBy (
1028
- pointerIndex: pointerExpr, count: unwrappedCountExpr , sizedBy: false ,
1015
+ pointerIndex: pointerExpr, count: rewrittenCountExpr , sizedBy: false ,
1029
1016
nonescaping: false , dependencies: [ ] , original: ExprSyntax ( enumConstructorExpr) )
1030
1017
}
1031
1018
1032
- func parseSizedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax ) throws -> ParamInfo {
1019
+ func parseSizedByEnum( _ enumConstructorExpr: FunctionCallExprSyntax , _ rewriter : CountExprRewriter ) throws -> ParamInfo {
1033
1020
let argumentList = enumConstructorExpr. arguments
1034
1021
let pointerExprArg = try getArgumentByName ( argumentList, " pointer " )
1035
1022
let pointerExpr : SwiftifyExpr = try parseSwiftifyExpr ( pointerExprArg)
@@ -1039,8 +1026,9 @@ func parseSizedByEnum(_ enumConstructorExpr: FunctionCallExprSyntax) throws -> P
1039
1026
" expected string literal for 'size' parameter, got \( sizeExprArg) " , node: sizeExprArg)
1040
1027
}
1041
1028
let unwrappedCountExpr = ExprSyntax ( stringLiteral: sizeExprStringLit. representedLiteralValue!)
1029
+ let rewrittenCountExpr = rewriter. visit ( unwrappedCountExpr)
1042
1030
return CountedBy (
1043
- pointerIndex: pointerExpr, count: unwrappedCountExpr , sizedBy: true , nonescaping: false ,
1031
+ pointerIndex: pointerExpr, count: rewrittenCountExpr , sizedBy: true , nonescaping: false ,
1044
1032
dependencies: [ ] , original: ExprSyntax ( enumConstructorExpr) )
1045
1033
}
1046
1034
@@ -1177,7 +1165,7 @@ func parseCxxSpansInSignature(
1177
1165
}
1178
1166
1179
1167
func parseMacroParam(
1180
- _ paramAST: LabeledExprSyntax , _ signature: FunctionSignatureSyntax ,
1168
+ _ paramAST: LabeledExprSyntax , _ signature: FunctionSignatureSyntax , _ rewriter : CountExprRewriter ,
1181
1169
nonescapingPointers: inout Set < Int > ,
1182
1170
lifetimeDependencies: inout [ SwiftifyExpr : [ LifetimeDependence ] ]
1183
1171
) throws -> ParamInfo ? {
@@ -1188,8 +1176,8 @@ func parseMacroParam(
1188
1176
}
1189
1177
let enumName = try parseEnumName ( paramExpr)
1190
1178
switch enumName {
1191
- case " countedBy " : return try parseCountedByEnum ( enumConstructorExpr, signature)
1192
- case " sizedBy " : return try parseSizedByEnum ( enumConstructorExpr)
1179
+ case " countedBy " : return try parseCountedByEnum ( enumConstructorExpr, signature, rewriter )
1180
+ case " sizedBy " : return try parseSizedByEnum ( enumConstructorExpr, rewriter )
1193
1181
case " endedBy " : return try parseEndedByEnum ( enumConstructorExpr)
1194
1182
case " nonescaping " :
1195
1183
let index = try parseNonEscaping ( enumConstructorExpr)
@@ -1438,7 +1426,7 @@ func paramLifetimeAttributes(
1438
1426
if !isMutableSpan( param. type) {
1439
1427
continue
1440
1428
}
1441
- let paramName = param. secondName ?? param . firstName
1429
+ let paramName = param. name
1442
1430
if containsLifetimeAttr ( oldAttrs, for: paramName) {
1443
1431
continue
1444
1432
}
@@ -1456,6 +1444,61 @@ func paramLifetimeAttributes(
1456
1444
return defaultLifetimes
1457
1445
}
1458
1446
1447
+ class CountExprRewriter : SyntaxRewriter {
1448
+ public let nameMap : [ String : String ]
1449
+
1450
+ init ( _ renamedParams: [ String : String ] ) {
1451
+ nameMap = renamedParams
1452
+ }
1453
+
1454
+ override func visit( _ node: DeclReferenceExprSyntax ) -> ExprSyntax {
1455
+ if let newName = nameMap [ node. baseName. trimmed. text] {
1456
+ return ExprSyntax (
1457
+ node. with (
1458
+ \. baseName,
1459
+ . identifier(
1460
+ newName, leadingTrivia: node. baseName. leadingTrivia,
1461
+ trailingTrivia: node. baseName. trailingTrivia) ) )
1462
+ }
1463
+ return escapeIfNeeded ( node)
1464
+ }
1465
+ }
1466
+
1467
+ func renameParameterNamesIfNeeded( _ funcDecl: FunctionDeclSyntax ) -> ( FunctionDeclSyntax , CountExprRewriter ) {
1468
+ let params = funcDecl. signature. parameterClause. parameters
1469
+ let funcName = funcDecl. name. withoutBackticks. trimmed. text
1470
+ let shouldRename = params. contains ( where: { param in
1471
+ let paramName = param. name. trimmed. text
1472
+ return paramName == " _ " || paramName == funcName || " ` \( paramName) ` " == funcName
1473
+ } )
1474
+ var renamedParams : [ String : String ] = [ : ]
1475
+ let newParams = params. enumerated ( ) . map { ( i, param) in
1476
+ let secondName = if shouldRename {
1477
+ // Including funcName in name prevents clash with function name.
1478
+ // Renaming all parameters if one requires renaming guarantees that other parameters don't clash with the renamed one.
1479
+ TokenSyntax ( " _ \( raw: funcName) _param \( raw: i) " )
1480
+ } else {
1481
+ param. secondName? . escapeIfNeeded
1482
+ }
1483
+ let firstName = param. firstName. escapeIfNeeded
1484
+ let newParam = param. with ( \. secondName, secondName)
1485
+ . with ( \. firstName, firstName)
1486
+ let newName = newParam. name. trimmed. text
1487
+ let oldName = param. name. trimmed. text
1488
+ if newName != oldName {
1489
+ renamedParams [ oldName] = newName
1490
+ }
1491
+ return newParam
1492
+ }
1493
+ let newDecl = if renamedParams. count > 0 {
1494
+ funcDecl. with ( \. signature. parameterClause. parameters, FunctionParameterListSyntax ( newParams) )
1495
+ } else {
1496
+ // Keeps source locations for diagnostics, in the common case where nothing was renamed
1497
+ funcDecl
1498
+ }
1499
+ return ( newDecl, CountExprRewriter ( renamedParams) )
1500
+ }
1501
+
1459
1502
/// A macro that adds safe(r) wrappers for functions with unsafe pointer types.
1460
1503
/// Depends on bounds, escapability and lifetime information for each pointer.
1461
1504
/// Intended to map to C attributes like __counted_by, __ended_by and __no_escape,
@@ -1469,9 +1512,10 @@ public struct SwiftifyImportMacro: PeerMacro {
1469
1512
in context: some MacroExpansionContext
1470
1513
) throws -> [ DeclSyntax ] {
1471
1514
do {
1472
- guard let funcDecl = declaration. as ( FunctionDeclSyntax . self) else {
1515
+ guard let origFuncDecl = declaration. as ( FunctionDeclSyntax . self) else {
1473
1516
throw DiagnosticError ( " @_SwiftifyImport only works on functions " , node: declaration)
1474
1517
}
1518
+ let ( funcDecl, rewriter) = renameParameterNamesIfNeeded ( origFuncDecl)
1475
1519
1476
1520
let argumentList = node. arguments!. as ( LabeledExprListSyntax . self) !
1477
1521
var arguments = [ LabeledExprSyntax] ( argumentList)
@@ -1487,7 +1531,7 @@ public struct SwiftifyImportMacro: PeerMacro {
1487
1531
var lifetimeDependencies : [ SwiftifyExpr : [ LifetimeDependence ] ] = [ : ]
1488
1532
var parsedArgs = try arguments. compactMap {
1489
1533
try parseMacroParam (
1490
- $0, funcDecl. signature, nonescapingPointers: & nonescapingPointers,
1534
+ $0, funcDecl. signature, rewriter , nonescapingPointers: & nonescapingPointers,
1491
1535
lifetimeDependencies: & lifetimeDependencies)
1492
1536
}
1493
1537
parsedArgs. append ( contentsOf: try parseCxxSpansInSignature ( funcDecl. signature, typeMappings) )
@@ -1627,3 +1671,33 @@ extension TypeSyntaxProtocol {
1627
1671
return false
1628
1672
}
1629
1673
}
1674
+
1675
+ extension FunctionParameterSyntax {
1676
+ var name : TokenSyntax {
1677
+ self . secondName ?? self . firstName
1678
+ }
1679
+ }
1680
+
1681
+ extension TokenSyntax {
1682
+ public var withoutBackticks : TokenSyntax {
1683
+ return . identifier( self . identifier!. name)
1684
+ }
1685
+
1686
+ public var escapeIfNeeded : TokenSyntax {
1687
+ var parser = Parser ( " let \( self ) " )
1688
+ let decl = DeclSyntax . parse ( from: & parser)
1689
+ if !decl. hasError {
1690
+ return self
1691
+ } else {
1692
+ return self . copyTrivia ( to: " ` \( raw: self . trimmed. text) ` " )
1693
+ }
1694
+ }
1695
+
1696
+ public func copyTrivia( to other: TokenSyntax ) -> TokenSyntax {
1697
+ return . identifier( other. text, leadingTrivia: self . leadingTrivia, trailingTrivia: self . trailingTrivia)
1698
+ }
1699
+ }
1700
+
1701
+ func escapeIfNeeded( _ identifier: DeclReferenceExprSyntax ) -> ExprSyntax {
1702
+ return " \( identifier. baseName. escapeIfNeeded) "
1703
+ }
0 commit comments