@@ -4,9 +4,6 @@ import SwiftSyntax
4
4
import SwiftSyntaxBuilder
5
5
import SwiftSyntaxMacros
6
6
7
- // Disable emitting 'MutableSpan' until it has landed
8
- let enableMutableSpan = false
9
-
10
7
// avoids depending on SwiftifyImport.swift
11
8
// all instances are reparsed and reinstantiated by the macro anyways,
12
9
// so linking is irrelevant
@@ -279,36 +276,49 @@ func getUnqualifiedStdName(_ type: String) -> String? {
279
276
func getSafePointerName( mut: Mutability , generateSpan: Bool , isRaw: Bool ) -> TokenSyntax {
280
277
switch ( mut, generateSpan, isRaw) {
281
278
case ( . Immutable, true , true ) : return " RawSpan "
282
- case ( . Mutable, true , true ) : return if enableMutableSpan {
283
- " MutableRawSpan "
284
- } else {
285
- " RawSpan "
286
- }
279
+ case ( . Mutable, true , true ) : return " MutableRawSpan "
287
280
case ( . Immutable, false , true ) : return " UnsafeRawBufferPointer "
288
281
case ( . Mutable, false , true ) : return " UnsafeMutableRawBufferPointer "
289
282
290
283
case ( . Immutable, true , false ) : return " Span "
291
- case ( . Mutable, true , false ) : return if enableMutableSpan {
292
- " MutableSpan "
293
- } else {
294
- " Span "
295
- }
284
+ case ( . Mutable, true , false ) : return " MutableSpan "
296
285
case ( . Immutable, false , false ) : return " UnsafeBufferPointer "
297
286
case ( . Mutable, false , false ) : return " UnsafeMutableBufferPointer "
298
287
}
299
288
}
300
289
301
- func transformType( _ prev: TypeSyntax , _ generateSpan: Bool , _ isSizedBy: Bool ) throws -> TypeSyntax {
290
+ func hasOwnershipSpecifier( _ attrType: AttributedTypeSyntax ) -> Bool {
291
+ return attrType. specifiers. contains ( where: { e in
292
+ guard let simpleSpec = e. as ( SimpleTypeSpecifierSyntax . self) else {
293
+ return false
294
+ }
295
+ let specifierText = simpleSpec. specifier. text
296
+ switch specifierText {
297
+ case " borrowing " :
298
+ return true
299
+ case " inout " :
300
+ return true
301
+ case " consuming " :
302
+ return true
303
+ default :
304
+ return false
305
+ }
306
+ } )
307
+ }
308
+
309
+ func transformType( _ prev: TypeSyntax , _ generateSpan: Bool , _ isSizedBy: Bool , _ setMutableSpanInout: Bool ) throws -> TypeSyntax {
302
310
if let optType = prev. as ( OptionalTypeSyntax . self) {
303
311
return TypeSyntax (
304
- optType. with ( \. wrappedType, try transformType ( optType. wrappedType, generateSpan, isSizedBy) ) )
312
+ optType. with ( \. wrappedType, try transformType ( optType. wrappedType, generateSpan, isSizedBy, setMutableSpanInout ) ) )
305
313
}
306
314
if let impOptType = prev. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
307
- return try transformType ( impOptType. wrappedType, generateSpan, isSizedBy)
315
+ return try transformType ( impOptType. wrappedType, generateSpan, isSizedBy, setMutableSpanInout )
308
316
}
309
317
if let attrType = prev. as ( AttributedTypeSyntax . self) {
318
+ // We insert 'inout' by default for MutableSpan, but it shouldn't override existing ownership
319
+ let setMutableSpanInoutNext = setMutableSpanInout && !hasOwnershipSpecifier( attrType)
310
320
return TypeSyntax (
311
- attrType. with ( \. baseType, try transformType ( attrType. baseType, generateSpan, isSizedBy) ) )
321
+ attrType. with ( \. baseType, try transformType ( attrType. baseType, generateSpan, isSizedBy, setMutableSpanInoutNext ) ) )
312
322
}
313
323
let name = try getTypeName ( prev)
314
324
let text = name. text
@@ -326,10 +336,15 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
326
336
+ " - first type token is ' \( text) ' " , node: name)
327
337
}
328
338
let token = getSafePointerName ( mut: kind, generateSpan: generateSpan, isRaw: isSizedBy)
329
- if isSizedBy {
330
- return TypeSyntax ( IdentifierTypeSyntax ( name: token) )
339
+ let mainType = if isSizedBy {
340
+ TypeSyntax ( IdentifierTypeSyntax ( name: token) )
341
+ } else {
342
+ try replaceTypeName ( prev, token)
331
343
}
332
- return try replaceTypeName ( prev, token)
344
+ if setMutableSpanInout && generateSpan && kind == . Mutable {
345
+ return TypeSyntax ( " inout \( mainType) " )
346
+ }
347
+ return mainType
333
348
}
334
349
335
350
func isMutablePointerType( _ type: TypeSyntax ) -> Bool {
@@ -431,10 +446,11 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
431
446
let colon : TokenSyntax ? = label != nil ? . colonToken( ) : nil
432
447
return LabeledExprSyntax ( label: label, colon: colon, expression: arg, trailingComma: comma)
433
448
}
434
- return ExprSyntax (
449
+ let call = ExprSyntax (
435
450
FunctionCallExprSyntax (
436
451
calledExpression: functionRef, leftParen: . leftParenToken( ) ,
437
452
arguments: LabeledExprListSyntax ( labeledArgs) , rightParen: . rightParenToken( ) ) )
453
+ return " unsafe \( call) "
438
454
}
439
455
}
440
456
@@ -446,6 +462,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
446
462
public let node : SyntaxProtocol
447
463
public let nonescaping : Bool
448
464
let isSizedBy : Bool = false
465
+ let isParameter : Bool = true
449
466
450
467
func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
451
468
return try base. buildBoundsChecks ( )
@@ -462,8 +479,26 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
462
479
var args = pointerArgs
463
480
let typeName = getUnattributedType ( oldType) . description
464
481
assert ( args [ index] == nil )
465
- args [ index] = ExprSyntax ( " \( raw: typeName) ( \( raw: name) ) " )
466
- return try base. buildFunctionCall ( args)
482
+
483
+ let ( _, isConst) = dropCxxQualifiers ( try genericArg)
484
+ if isConst {
485
+ args [ index] = ExprSyntax ( " \( raw: typeName) ( \( raw: name) ) " )
486
+ return try base. buildFunctionCall ( args)
487
+ } else {
488
+ let unwrappedName = TokenSyntax ( " _ \( name) Ptr " )
489
+ args [ index] = ExprSyntax ( " \( raw: typeName) ( \( unwrappedName) ) " )
490
+ let call = try base. buildFunctionCall ( args)
491
+
492
+ // MutableSpan - unlike Span - cannot be bitcast to std::span due to being ~Copyable,
493
+ // so unwrap it to an UnsafeMutableBufferPointer that we can cast
494
+ let unwrappedCall = ExprSyntax (
495
+ """
496
+ unsafe \( name) .withUnsafeMutableBufferPointer { \( unwrappedName) in
497
+ return \( call)
498
+ }
499
+ """ )
500
+ return unwrappedCall
501
+ }
467
502
}
468
503
}
469
504
@@ -472,6 +507,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
472
507
public let signature : FunctionSignatureSyntax
473
508
public let typeMappings : [ String : String ]
474
509
public let node : SyntaxProtocol
510
+ let isParameter : Bool = false
475
511
476
512
var oldType : TypeSyntax {
477
513
return signature. returnClause!. type
@@ -490,12 +526,12 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
490
526
func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
491
527
let call = try base. buildFunctionCall ( pointerArgs)
492
528
let ( _, isConst) = dropCxxQualifiers ( try genericArg)
493
- let cast = if isConst || !enableMutableSpan {
529
+ let cast = if isConst {
494
530
" Span "
495
531
} else {
496
532
" MutableSpan "
497
533
}
498
- return " _cxxOverrideLifetime( \( raw: cast) (_unsafeCxxSpan: \( call) ), copying: ()) "
534
+ return " unsafe _cxxOverrideLifetime(\( raw: cast) (_unsafeCxxSpan: \( call) ), copying: ()) "
499
535
}
500
536
}
501
537
@@ -508,11 +544,12 @@ protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder {
508
544
protocol SpanBoundsThunkBuilder : BoundsThunkBuilder {
509
545
var typeMappings : [ String : String ] { get }
510
546
var node : SyntaxProtocol { get }
547
+ var isParameter : Bool { get }
511
548
}
512
549
extension SpanBoundsThunkBuilder {
513
550
var desugaredType : TypeSyntax {
514
551
get throws {
515
- let typeName = try getUnattributedType ( oldType) . description
552
+ let typeName = getUnattributedType ( oldType) . description
516
553
guard let desugaredTypeName = typeMappings [ typeName] else {
517
554
throw DiagnosticError (
518
555
" unable to desugar type with name ' \( typeName) ' " , node: node)
@@ -547,14 +584,18 @@ extension SpanBoundsThunkBuilder {
547
584
var newType : TypeSyntax {
548
585
get throws {
549
586
let ( strippedArg, isConst) = dropCxxQualifiers ( try genericArg)
550
- let mutablePrefix = if isConst || !enableMutableSpan {
587
+ let mutablePrefix = if isConst {
551
588
" "
552
589
} else {
553
590
" Mutable "
554
591
}
555
- return replaceBaseType (
592
+ let mainType = replaceBaseType (
556
593
oldType,
557
594
TypeSyntax ( " \( raw: mutablePrefix) Span< \( raw: strippedArg) > " ) )
595
+ if !isConst && isParameter {
596
+ return TypeSyntax ( " inout \( mainType) " )
597
+ }
598
+ return mainType
558
599
}
559
600
}
560
601
}
@@ -563,13 +604,14 @@ protocol PointerBoundsThunkBuilder: BoundsThunkBuilder {
563
604
var nullable : Bool { get }
564
605
var isSizedBy : Bool { get }
565
606
var generateSpan : Bool { get }
607
+ var isParameter : Bool { get }
566
608
}
567
609
568
610
extension PointerBoundsThunkBuilder {
569
611
var nullable : Bool { return oldType. is ( OptionalTypeSyntax . self) }
570
612
571
613
var newType : TypeSyntax { get throws {
572
- return try transformType ( oldType, generateSpan, isSizedBy) }
614
+ return try transformType ( oldType, generateSpan, isSizedBy, isParameter ) }
573
615
}
574
616
}
575
617
@@ -599,8 +641,9 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
599
641
public let nonescaping : Bool
600
642
public let isSizedBy : Bool
601
643
public let dependencies : [ LifetimeDependence ]
644
+ let isParameter : Bool = false
602
645
603
- var generateSpan : Bool { !dependencies. isEmpty && ( !isMutablePointerType ( oldType ) || enableMutableSpan ) }
646
+ var generateSpan : Bool { !dependencies. isEmpty }
604
647
605
648
var oldType : TypeSyntax {
606
649
return signature. returnClause!. type
@@ -623,9 +666,25 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
623
666
} else {
624
667
" start "
625
668
}
669
+ var cast = try newType
670
+ if nullable {
671
+ if let optType = cast. as ( OptionalTypeSyntax . self) {
672
+ cast = optType. wrappedType
673
+ }
674
+ return """
675
+ { () in
676
+ let _resultValue = \( call)
677
+ if unsafe _resultValue == nil {
678
+ return nil
679
+ } else {
680
+ return unsafe \( raw: try cast) ( \( raw: startLabel) : _resultValue!, count: Int( \( countExpr) ))
681
+ }
682
+ }()
683
+ """
684
+ }
626
685
return
627
686
"""
628
- \( raw: try newType ) ( \( raw: startLabel) : \( call) , count: Int( \( countExpr) ))
687
+ unsafe \( raw: try cast ) ( \( raw: startLabel) : \( call) , count: Int( \( countExpr) ))
629
688
"""
630
689
}
631
690
}
@@ -639,8 +698,9 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
639
698
public let nonescaping : Bool
640
699
public let isSizedBy : Bool
641
700
public let skipTrivialCount : Bool
701
+ let isParameter : Bool = true
642
702
643
- var generateSpan : Bool { nonescaping && ( !isMutablePointerType ( oldType ) || enableMutableSpan ) }
703
+ var generateSpan : Bool { nonescaping }
644
704
645
705
func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
646
706
-> ( FunctionSignatureSyntax , Bool ) {
@@ -702,11 +762,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
702
762
let call = try base. buildFunctionCall ( args)
703
763
let ptrRef = unwrapIfNullable ( ExprSyntax ( DeclReferenceExprSyntax ( baseName: name) ) )
704
764
705
- let funcName = isSizedBy ? " withUnsafeBytes " : " withUnsafeBufferPointer "
765
+ let funcName = switch ( isSizedBy, isMutablePointerType ( oldType) ) {
766
+ case ( true , true ) : " withUnsafeMutableBytes "
767
+ case ( true , false ) : " withUnsafeBytes "
768
+ case ( false , true ) : " withUnsafeMutableBufferPointer "
769
+ case ( false , false ) : " withUnsafeBufferPointer "
770
+ }
706
771
let unwrappedCall = ExprSyntax (
707
772
"""
708
- \( ptrRef) . \( raw: funcName) { \( unwrappedName) in
709
- return unsafe \( call)
773
+ unsafe \( ptrRef) . \( raw: funcName) { \( unwrappedName) in
774
+ return \( call)
710
775
}
711
776
""" )
712
777
return unwrappedCall
@@ -766,11 +831,11 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
766
831
nullArgs [ index] = ExprSyntax ( NilLiteralExprSyntax ( nilKeyword: . keyword( . nil ) ) )
767
832
return ExprSyntax (
768
833
"""
769
- if \( name) == nil {
770
- unsafe \( try base. buildFunctionCall ( nullArgs) )
771
- } else {
772
- \( unwrappedCall)
773
- }
834
+ { () in return if \( name) == nil {
835
+ \( try base. buildFunctionCall ( nullArgs) )
836
+ } else {
837
+ \( unwrappedCall)
838
+ } }()
774
839
""" )
775
840
}
776
841
return unwrappedCall
@@ -1161,7 +1226,7 @@ public struct SwiftifyImportMacro: PeerMacro {
1161
1226
}
1162
1227
}
1163
1228
1164
- static func lifetimeAttributes ( _ funcDecl: FunctionDeclSyntax ,
1229
+ static func getReturnLifetimeAttribute ( _ funcDecl: FunctionDeclSyntax ,
1165
1230
_ dependencies: [ SwiftifyExpr : [ LifetimeDependence ] ] ) -> [ AttributeListSyntax . Element ] {
1166
1231
let returnDependencies = dependencies [ . `return`, default: [ ] ]
1167
1232
if returnDependencies. isEmpty {
@@ -1190,6 +1255,66 @@ public struct SwiftifyImportMacro: PeerMacro {
1190
1255
rightParen: . rightParenToken( ) ) ) ]
1191
1256
}
1192
1257
1258
+ static func isMutableSpan( _ type: TypeSyntax ) -> Bool {
1259
+ if let optType = type. as ( OptionalTypeSyntax . self) {
1260
+ return isMutableSpan ( optType. wrappedType)
1261
+ }
1262
+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
1263
+ return isMutableSpan ( impOptType. wrappedType)
1264
+ }
1265
+ if let attrType = type. as ( AttributedTypeSyntax . self) {
1266
+ return isMutableSpan ( attrType. baseType)
1267
+ }
1268
+ guard let identifierType = type. as ( IdentifierTypeSyntax . self) else {
1269
+ return false
1270
+ }
1271
+ let name = identifierType. name. text
1272
+ return name == " MutableSpan " || name == " MutableRawSpan "
1273
+ }
1274
+
1275
+ static func containsLifetimeAttr( _ attrs: AttributeListSyntax , for paramName: TokenSyntax ) -> Bool {
1276
+ for elem in attrs {
1277
+ guard let attr = elem. as ( AttributeSyntax . self) else {
1278
+ continue
1279
+ }
1280
+ if attr. attributeName != " lifetime " {
1281
+ continue
1282
+ }
1283
+ guard let args = attr. arguments? . as ( LabeledExprListSyntax . self) else {
1284
+ continue
1285
+ }
1286
+ for arg in args {
1287
+ if arg. label == paramName {
1288
+ return true
1289
+ }
1290
+ }
1291
+ }
1292
+ return false
1293
+ }
1294
+
1295
+ // Mutable[Raw]Span parameters need explicit @lifetime annotations since they are inout
1296
+ static func paramLifetimeAttributes( _ newSignature: FunctionSignatureSyntax , _ oldAttrs: AttributeListSyntax ) -> [ AttributeListSyntax . Element ] {
1297
+ var defaultLifetimes : [ AttributeListSyntax . Element ] = [ ]
1298
+ for param in newSignature. parameterClause. parameters {
1299
+ if !isMutableSpan( param. type) {
1300
+ continue
1301
+ }
1302
+ let paramName = param. secondName ?? param. firstName
1303
+ if containsLifetimeAttr ( oldAttrs, for: paramName) {
1304
+ continue
1305
+ }
1306
+ let expr = ExprSyntax ( " \( paramName) : copy \( paramName) " )
1307
+
1308
+ defaultLifetimes. append ( . attribute( AttributeSyntax (
1309
+ atSign: . atSignToken( ) ,
1310
+ attributeName: IdentifierTypeSyntax ( name: " lifetime " ) ,
1311
+ leftParen: . leftParenToken( ) ,
1312
+ arguments: . argumentList( LabeledExprListSyntax ( [ LabeledExprSyntax ( expression: expr) ] ) ) ,
1313
+ rightParen: . rightParenToken( ) ) ) )
1314
+ }
1315
+ return defaultLifetimes
1316
+ }
1317
+
1193
1318
public static func expansion(
1194
1319
of node: AttributeSyntax ,
1195
1320
providingPeersOf declaration: some DeclSyntaxProtocol ,
@@ -1255,9 +1380,10 @@ public struct SwiftifyImportMacro: PeerMacro {
1255
1380
item: CodeBlockItemSyntax . Item (
1256
1381
ReturnStmtSyntax (
1257
1382
returnKeyword: . keyword( . return, trailingTrivia: " " ) ,
1258
- expression: ExprSyntax ( " unsafe \( try builder. buildFunctionCall ( [ : ] ) ) " ) ) ) )
1383
+ expression: try builder. buildFunctionCall ( [ : ] ) ) ) )
1259
1384
let body = CodeBlockSyntax ( statements: CodeBlockItemListSyntax ( checks + [ call] ) )
1260
- let lifetimeAttrs = lifetimeAttributes ( funcDecl, lifetimeDependencies)
1385
+ let returnLifetimeAttribute = getReturnLifetimeAttribute ( funcDecl, lifetimeDependencies)
1386
+ let lifetimeAttrs = returnLifetimeAttribute + paramLifetimeAttributes( newSignature, funcDecl. attributes)
1261
1387
let disfavoredOverload : [ AttributeListSyntax . Element ] = ( onlyReturnTypeChanged ? [
1262
1388
. attribute(
1263
1389
AttributeSyntax (
0 commit comments