@@ -4,6 +4,9 @@ import SwiftSyntax
4
4
import SwiftSyntaxBuilder
5
5
import SwiftSyntaxMacros
6
6
7
+ // Disable emitting 'MutableSpan' until it has landed
8
+ let enableMutableSpan = false
9
+
7
10
// avoids depending on SwiftifyImport.swift
8
11
// all instances are reparsed and reinstantiated by the macro anyways,
9
12
// so linking is irrelevant
@@ -213,22 +216,26 @@ func replaceBaseType(_ type: TypeSyntax, _ base: TypeSyntax) -> TypeSyntax {
213
216
214
217
// C++ type qualifiers, `const T` and `volatile T`, are encoded as fake generic
215
218
// types, `__cxxConst<T>` and `__cxxVolatile<T>` respectively. Remove those.
216
- func dropQualifierGenerics( _ type: TypeSyntax ) -> TypeSyntax {
217
- guard let identifier = type. as ( IdentifierTypeSyntax . self) else { return type }
218
- guard let generic = identifier. genericArgumentClause else { return type }
219
- guard let genericArg = generic. arguments. first else { return type }
220
- guard case . type( let argType) = genericArg. argument else { return type }
219
+ // Second return value is true if __cxxConst was stripped.
220
+ func dropQualifierGenerics( _ type: TypeSyntax ) -> ( TypeSyntax , Bool ) {
221
+ guard let identifier = type. as ( IdentifierTypeSyntax . self) else { return ( type, false ) }
222
+ guard let generic = identifier. genericArgumentClause else { return ( type, false ) }
223
+ guard let genericArg = generic. arguments. first else { return ( type, false ) }
224
+ guard case . type( let argType) = genericArg. argument else { return ( type, false ) }
221
225
switch identifier. name. text {
222
- case " __cxxConst " , " __cxxVolatile " :
226
+ case " __cxxConst " :
227
+ let ( retType, _) = dropQualifierGenerics ( argType)
228
+ return ( retType, true )
229
+ case " __cxxVolatile " :
223
230
return dropQualifierGenerics ( argType)
224
231
default :
225
- return type
232
+ return ( type, false )
226
233
}
227
234
}
228
235
229
236
// The generated type names for template instantiations sometimes contain
230
237
// encoded qualifiers for disambiguation purposes. We need to remove those.
231
- func dropCxxQualifiers( _ type: TypeSyntax ) -> TypeSyntax {
238
+ func dropCxxQualifiers( _ type: TypeSyntax ) -> ( TypeSyntax , Bool ) {
232
239
if let attributed = type. as ( AttributedTypeSyntax . self) {
233
240
return dropCxxQualifiers ( attributed. baseType)
234
241
}
@@ -272,12 +279,20 @@ func getUnqualifiedStdName(_ type: String) -> String? {
272
279
func getSafePointerName( mut: Mutability , generateSpan: Bool , isRaw: Bool ) -> TokenSyntax {
273
280
switch ( mut, generateSpan, isRaw) {
274
281
case ( . Immutable, true , true ) : return " RawSpan "
275
- case ( . Mutable, true , true ) : return " MutableRawSpan "
282
+ case ( . Mutable, true , true ) : return if enableMutableSpan {
283
+ " MutableRawSpan "
284
+ } else {
285
+ " RawSpan "
286
+ }
276
287
case ( . Immutable, false , true ) : return " UnsafeRawBufferPointer "
277
288
case ( . Mutable, false , true ) : return " UnsafeMutableRawBufferPointer "
278
289
279
290
case ( . Immutable, true , false ) : return " Span "
280
- case ( . Mutable, true , false ) : return " MutableSpan "
291
+ case ( . Mutable, true , false ) : return if enableMutableSpan {
292
+ " MutableSpan "
293
+ } else {
294
+ " Span "
295
+ }
281
296
case ( . Immutable, false , false ) : return " UnsafeBufferPointer "
282
297
case ( . Mutable, false , false ) : return " UnsafeMutableBufferPointer "
283
298
}
@@ -317,6 +332,28 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
317
332
return try replaceTypeName ( prev, token)
318
333
}
319
334
335
+ func isMutablePointerType( _ type: TypeSyntax ) -> Bool {
336
+ if let optType = type. as ( OptionalTypeSyntax . self) {
337
+ return isMutablePointerType ( optType. wrappedType)
338
+ }
339
+ if let impOptType = type. as ( ImplicitlyUnwrappedOptionalTypeSyntax . self) {
340
+ return isMutablePointerType ( impOptType. wrappedType)
341
+ }
342
+ if let attrType = type. as ( AttributedTypeSyntax . self) {
343
+ return isMutablePointerType ( attrType. baseType)
344
+ }
345
+ do {
346
+ let name = try getTypeName ( type)
347
+ let text = name. text
348
+ guard let kind: Mutability = getPointerMutability ( text: text) else {
349
+ return false
350
+ }
351
+ return kind == . Mutable
352
+ } catch _ {
353
+ return false
354
+ }
355
+ }
356
+
320
357
protocol BoundsCheckedThunkBuilder {
321
358
func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax
322
359
func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ]
@@ -401,7 +438,7 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
401
438
}
402
439
}
403
440
404
- struct CxxSpanThunkBuilder : ParamPointerBoundsThunkBuilder {
441
+ struct CxxSpanThunkBuilder : SpanBoundsThunkBuilder , ParamBoundsThunkBuilder {
405
442
public let base : BoundsCheckedThunkBuilder
406
443
public let index : Int
407
444
public let signature : FunctionSignatureSyntax
@@ -417,17 +454,7 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
417
454
func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
418
455
-> ( FunctionSignatureSyntax , Bool ) {
419
456
var types = argTypes
420
- let typeName = getUnattributedType ( oldType) . description
421
- guard let desugaredType = typeMappings [ typeName] else {
422
- throw DiagnosticError (
423
- " unable to desugar type with name ' \( typeName) ' " , node: node)
424
- }
425
-
426
- let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
427
- let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
428
- . genericArgumentClause!. arguments. first!. argument) !
429
- types [ index] = replaceBaseType ( param. type,
430
- TypeSyntax ( " Span< \( raw: dropCxxQualifiers ( genericArg) ) > " ) )
457
+ types [ index] = try newType
431
458
return try base. buildFunctionSignature ( types, returnType)
432
459
}
433
460
@@ -440,44 +467,100 @@ struct CxxSpanThunkBuilder: ParamPointerBoundsThunkBuilder {
440
467
}
441
468
}
442
469
443
- struct CxxSpanReturnThunkBuilder : BoundsCheckedThunkBuilder {
470
+ struct CxxSpanReturnThunkBuilder : SpanBoundsThunkBuilder {
444
471
public let base : BoundsCheckedThunkBuilder
445
472
public let signature : FunctionSignatureSyntax
446
473
public let typeMappings : [ String : String ]
447
474
public let node : SyntaxProtocol
448
475
476
+ var oldType : TypeSyntax {
477
+ return signature. returnClause!. type
478
+ }
479
+
449
480
func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
450
481
return try base. buildBoundsChecks ( )
451
482
}
452
483
453
484
func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
454
485
-> ( FunctionSignatureSyntax , Bool ) {
455
486
assert ( returnType == nil )
456
- let typeName = getUnattributedType ( signature. returnClause!. type) . description
457
- guard let desugaredType = typeMappings [ typeName] else {
458
- throw DiagnosticError (
459
- " unable to desugar type with name ' \( typeName) ' " , node: node)
460
- }
461
- let parsedDesugaredType = TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredType) !) " )
462
- let genericArg = TypeSyntax ( parsedDesugaredType. as ( IdentifierTypeSyntax . self) !
463
- . genericArgumentClause!. arguments. first!. argument) !
464
- let newType = replaceBaseType ( signature. returnClause!. type,
465
- TypeSyntax ( " Span< \( raw: dropCxxQualifiers ( genericArg) ) > " ) )
466
487
return try base. buildFunctionSignature ( argTypes, newType)
467
488
}
468
489
469
490
func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
470
491
let call = try base. buildFunctionCall ( pointerArgs)
471
- return " _cxxOverrideLifetime(Span(_unsafeCxxSpan: \( call) ), copying: ()) "
492
+ let ( _, isConst) = dropCxxQualifiers ( try genericArg)
493
+ let cast = if isConst || !enableMutableSpan {
494
+ " Span "
495
+ } else {
496
+ " MutableSpan "
497
+ }
498
+ return " _cxxOverrideLifetime( \( raw: cast) (_unsafeCxxSpan: \( call) ), copying: ()) "
472
499
}
473
500
}
474
501
475
- protocol PointerBoundsThunkBuilder : BoundsCheckedThunkBuilder {
502
+ protocol BoundsThunkBuilder : BoundsCheckedThunkBuilder {
476
503
var oldType : TypeSyntax { get }
477
504
var newType : TypeSyntax { get throws }
478
- var nullable : Bool { get }
479
505
var signature : FunctionSignatureSyntax { get }
480
- var nonescaping : Bool { get }
506
+ }
507
+
508
+ protocol SpanBoundsThunkBuilder : BoundsThunkBuilder {
509
+ var typeMappings : [ String : String ] { get }
510
+ var node : SyntaxProtocol { get }
511
+ }
512
+ extension SpanBoundsThunkBuilder {
513
+ var desugaredType : TypeSyntax {
514
+ get throws {
515
+ let typeName = try getUnattributedType ( oldType) . description
516
+ guard let desugaredTypeName = typeMappings [ typeName] else {
517
+ throw DiagnosticError (
518
+ " unable to desugar type with name ' \( typeName) ' " , node: node)
519
+ }
520
+ return TypeSyntax ( " \( raw: getUnqualifiedStdName ( desugaredTypeName) !) " )
521
+ }
522
+ }
523
+ var genericArg : TypeSyntax {
524
+ get throws {
525
+ guard let idType = try desugaredType. as ( IdentifierTypeSyntax . self) else {
526
+ throw DiagnosticError (
527
+ " unexpected non-identifier type ' \( try desugaredType) ', expected a std::span type " ,
528
+ node: try desugaredType)
529
+ }
530
+ guard let genericArgumentClause = idType. genericArgumentClause else {
531
+ throw DiagnosticError (
532
+ " missing generic type argument clause expected after \( idType) " , node: idType)
533
+ }
534
+ guard let firstArg = genericArgumentClause. arguments. first else {
535
+ throw DiagnosticError (
536
+ " expected at least 1 generic type argument for std::span type ' \( idType) ', found ' \( genericArgumentClause) ' " ,
537
+ node: genericArgumentClause. arguments)
538
+ }
539
+ guard let arg = TypeSyntax ( firstArg. argument) else {
540
+ throw DiagnosticError (
541
+ " invalid generic type argument ' \( firstArg. argument) ' " ,
542
+ node: firstArg. argument)
543
+ }
544
+ return arg
545
+ }
546
+ }
547
+ var newType : TypeSyntax {
548
+ get throws {
549
+ let ( strippedArg, isConst) = dropCxxQualifiers ( try genericArg)
550
+ let mutablePrefix = if isConst || !enableMutableSpan {
551
+ " "
552
+ } else {
553
+ " Mutable "
554
+ }
555
+ return replaceBaseType (
556
+ oldType,
557
+ TypeSyntax ( " \( raw: mutablePrefix) Span< \( raw: strippedArg) > " ) )
558
+ }
559
+ }
560
+ }
561
+
562
+ protocol PointerBoundsThunkBuilder : BoundsThunkBuilder {
563
+ var nullable : Bool { get }
481
564
var isSizedBy : Bool { get }
482
565
var generateSpan : Bool { get }
483
566
}
@@ -490,13 +573,12 @@ extension PointerBoundsThunkBuilder {
490
573
}
491
574
}
492
575
493
- protocol ParamPointerBoundsThunkBuilder : PointerBoundsThunkBuilder {
576
+ protocol ParamBoundsThunkBuilder : BoundsThunkBuilder {
494
577
var index : Int { get }
578
+ var nonescaping : Bool { get }
495
579
}
496
580
497
- extension ParamPointerBoundsThunkBuilder {
498
- var generateSpan : Bool { nonescaping }
499
-
581
+ extension ParamBoundsThunkBuilder {
500
582
var param : FunctionParameterSyntax {
501
583
return getParam ( signature, index)
502
584
}
@@ -518,7 +600,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
518
600
public let isSizedBy : Bool
519
601
public let dependencies : [ LifetimeDependence ]
520
602
521
- var generateSpan : Bool { !dependencies. isEmpty }
603
+ var generateSpan : Bool { !dependencies. isEmpty && ( !isMutablePointerType ( oldType ) || enableMutableSpan ) }
522
604
523
605
var oldType : TypeSyntax {
524
606
return signature. returnClause!. type
@@ -531,7 +613,7 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
531
613
}
532
614
533
615
func buildBoundsChecks( ) throws -> [ CodeBlockItemSyntax . Item ] {
534
- return [ ]
616
+ return try base . buildBoundsChecks ( )
535
617
}
536
618
537
619
func buildFunctionCall( _ pointerArgs: [ Int : ExprSyntax ] ) throws -> ExprSyntax {
@@ -548,7 +630,8 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
548
630
}
549
631
}
550
632
551
- struct CountedOrSizedPointerThunkBuilder : ParamPointerBoundsThunkBuilder {
633
+
634
+ struct CountedOrSizedPointerThunkBuilder : ParamBoundsThunkBuilder , PointerBoundsThunkBuilder {
552
635
public let base : BoundsCheckedThunkBuilder
553
636
public let index : Int
554
637
public let countExpr : ExprSyntax
@@ -557,6 +640,8 @@ struct CountedOrSizedPointerThunkBuilder: ParamPointerBoundsThunkBuilder {
557
640
public let isSizedBy : Bool
558
641
public let skipTrivialCount : Bool
559
642
643
+ var generateSpan : Bool { nonescaping && ( !isMutablePointerType( oldType) || enableMutableSpan) }
644
+
560
645
func buildFunctionSignature( _ argTypes: [ Int : TypeSyntax ? ] , _ returnType: TypeSyntax ? ) throws
561
646
-> ( FunctionSignatureSyntax , Bool ) {
562
647
var types = argTypes
0 commit comments