Skip to content

Commit 87f5309

Browse files
authored
[Swiftify] enable mutable span (#80387)
* [Swiftify] Emit Mutable[Raw]Span when possible Previously wrappers would use UnsafeMutable[Raw]Pointer for mutable pointers, and Span for non-const std::span, to prevent the compiler from complaining that MutableSpan didn't exist. Now that MutableSpan has landed we can finally emit MutableSpan without causing compilation errors. While we had (disabled) support for MutableSpan syntax already, some unexpected semantic errors required additional changes: - Mutable[Raw]Span parameters need to be inout (for mutation) - inout ~Escapable paramters need explicit lifetime annotations - MutableSpan cannot be directly bitcast to std::span, because it is ~Copyable, so they need unwrapping to UnsafeMutableBufferPointer rdar://147883022 * [Swiftify] Wrap if-expressions in Immediately Called Closures When parameters in swiftified wrapper functions are nullable, we use separate branches for the nil and nonnil cases, because `withUnsafeBufferPointer` (and similar) cannot be called on nil. If-expressions have some limitations on where they are allowed in the grammar, and cannot be passed as arguments to a function. As such, when the return value is also swiftified, we get an error when trying to pass the if-expression to the UnsafeBufferPointer/Span constructor. While it isn't pretty, the best way forward seems to be by wrapping the if-expressions in Immediately Called Closures. The closures have the side-effect of acting as a barrier for 'unsafe': unsafe keywords outside the closure do not "reach" unsafe expressions inside the closure. We therefore have to emit "unsafe" where unsafe expressions are used, rather than just when returning. rdar://148153063
1 parent 813f1dc commit 87f5309

File tree

13 files changed

+436
-94
lines changed

13 files changed

+436
-94
lines changed

lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift

Lines changed: 169 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@ import SwiftSyntax
44
import SwiftSyntaxBuilder
55
import SwiftSyntaxMacros
66

7-
// Disable emitting 'MutableSpan' until it has landed
8-
let enableMutableSpan = false
9-
107
// avoids depending on SwiftifyImport.swift
118
// all instances are reparsed and reinstantiated by the macro anyways,
129
// so linking is irrelevant
@@ -279,36 +276,49 @@ func getUnqualifiedStdName(_ type: String) -> String? {
279276
func getSafePointerName(mut: Mutability, generateSpan: Bool, isRaw: Bool) -> TokenSyntax {
280277
switch (mut, generateSpan, isRaw) {
281278
case (.Immutable, true, true): return "RawSpan"
282-
case (.Mutable, true, true): return if enableMutableSpan {
283-
"MutableRawSpan"
284-
} else {
285-
"RawSpan"
286-
}
279+
case (.Mutable, true, true): return "MutableRawSpan"
287280
case (.Immutable, false, true): return "UnsafeRawBufferPointer"
288281
case (.Mutable, false, true): return "UnsafeMutableRawBufferPointer"
289282

290283
case (.Immutable, true, false): return "Span"
291-
case (.Mutable, true, false): return if enableMutableSpan {
292-
"MutableSpan"
293-
} else {
294-
"Span"
295-
}
284+
case (.Mutable, true, false): return "MutableSpan"
296285
case (.Immutable, false, false): return "UnsafeBufferPointer"
297286
case (.Mutable, false, false): return "UnsafeMutableBufferPointer"
298287
}
299288
}
300289

301-
func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool) throws -> TypeSyntax {
290+
func hasOwnershipSpecifier(_ attrType: AttributedTypeSyntax) -> Bool {
291+
return attrType.specifiers.contains(where: { e in
292+
guard let simpleSpec = e.as(SimpleTypeSpecifierSyntax.self) else {
293+
return false
294+
}
295+
let specifierText = simpleSpec.specifier.text
296+
switch specifierText {
297+
case "borrowing":
298+
return true
299+
case "inout":
300+
return true
301+
case "consuming":
302+
return true
303+
default:
304+
return false
305+
}
306+
})
307+
}
308+
309+
func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool, _ setMutableSpanInout: Bool) throws -> TypeSyntax {
302310
if let optType = prev.as(OptionalTypeSyntax.self) {
303311
return TypeSyntax(
304-
optType.with(\.wrappedType, try transformType(optType.wrappedType, generateSpan, isSizedBy)))
312+
optType.with(\.wrappedType, try transformType(optType.wrappedType, generateSpan, isSizedBy, setMutableSpanInout)))
305313
}
306314
if let impOptType = prev.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
307-
return try transformType(impOptType.wrappedType, generateSpan, isSizedBy)
315+
return try transformType(impOptType.wrappedType, generateSpan, isSizedBy, setMutableSpanInout)
308316
}
309317
if let attrType = prev.as(AttributedTypeSyntax.self) {
318+
// We insert 'inout' by default for MutableSpan, but it shouldn't override existing ownership
319+
let setMutableSpanInoutNext = setMutableSpanInout && !hasOwnershipSpecifier(attrType)
310320
return TypeSyntax(
311-
attrType.with(\.baseType, try transformType(attrType.baseType, generateSpan, isSizedBy)))
321+
attrType.with(\.baseType, try transformType(attrType.baseType, generateSpan, isSizedBy, setMutableSpanInoutNext)))
312322
}
313323
let name = try getTypeName(prev)
314324
let text = name.text
@@ -326,10 +336,15 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
326336
+ " - first type token is '\(text)'", node: name)
327337
}
328338
let token = getSafePointerName(mut: kind, generateSpan: generateSpan, isRaw: isSizedBy)
329-
if isSizedBy {
330-
return TypeSyntax(IdentifierTypeSyntax(name: token))
339+
let mainType = if isSizedBy {
340+
TypeSyntax(IdentifierTypeSyntax(name: token))
341+
} else {
342+
try replaceTypeName(prev, token)
331343
}
332-
return try replaceTypeName(prev, token)
344+
if setMutableSpanInout && generateSpan && kind == .Mutable {
345+
return TypeSyntax("inout \(mainType)")
346+
}
347+
return mainType
333348
}
334349

335350
func isMutablePointerType(_ type: TypeSyntax) -> Bool {
@@ -431,10 +446,11 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
431446
let colon: TokenSyntax? = label != nil ? .colonToken() : nil
432447
return LabeledExprSyntax(label: label, colon: colon, expression: arg, trailingComma: comma)
433448
}
434-
return ExprSyntax(
449+
let call = ExprSyntax(
435450
FunctionCallExprSyntax(
436451
calledExpression: functionRef, leftParen: .leftParenToken(),
437452
arguments: LabeledExprListSyntax(labeledArgs), rightParen: .rightParenToken()))
453+
return "unsafe \(call)"
438454
}
439455
}
440456

@@ -446,6 +462,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
446462
public let node: SyntaxProtocol
447463
public let nonescaping: Bool
448464
let isSizedBy: Bool = false
465+
let isParameter: Bool = true
449466

450467
func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
451468
return try base.buildBoundsChecks()
@@ -462,8 +479,26 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
462479
var args = pointerArgs
463480
let typeName = getUnattributedType(oldType).description
464481
assert(args[index] == nil)
465-
args[index] = ExprSyntax("\(raw: typeName)(\(raw: name))")
466-
return try base.buildFunctionCall(args)
482+
483+
let (_, isConst) = dropCxxQualifiers(try genericArg)
484+
if isConst {
485+
args[index] = ExprSyntax("\(raw: typeName)(\(raw: name))")
486+
return try base.buildFunctionCall(args)
487+
} else {
488+
let unwrappedName = TokenSyntax("_\(name)Ptr")
489+
args[index] = ExprSyntax("\(raw: typeName)(\(unwrappedName))")
490+
let call = try base.buildFunctionCall(args)
491+
492+
// MutableSpan - unlike Span - cannot be bitcast to std::span due to being ~Copyable,
493+
// so unwrap it to an UnsafeMutableBufferPointer that we can cast
494+
let unwrappedCall = ExprSyntax(
495+
"""
496+
unsafe \(name).withUnsafeMutableBufferPointer { \(unwrappedName) in
497+
return \(call)
498+
}
499+
""")
500+
return unwrappedCall
501+
}
467502
}
468503
}
469504

@@ -472,6 +507,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
472507
public let signature: FunctionSignatureSyntax
473508
public let typeMappings: [String: String]
474509
public let node: SyntaxProtocol
510+
let isParameter: Bool = false
475511

476512
var oldType: TypeSyntax {
477513
return signature.returnClause!.type
@@ -490,12 +526,12 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
490526
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
491527
let call = try base.buildFunctionCall(pointerArgs)
492528
let (_, isConst) = dropCxxQualifiers(try genericArg)
493-
let cast = if isConst || !enableMutableSpan {
529+
let cast = if isConst {
494530
"Span"
495531
} else {
496532
"MutableSpan"
497533
}
498-
return "_cxxOverrideLifetime(\(raw: cast)(_unsafeCxxSpan: \(call)), copying: ())"
534+
return "unsafe _cxxOverrideLifetime(\(raw: cast)(_unsafeCxxSpan: \(call)), copying: ())"
499535
}
500536
}
501537

@@ -508,11 +544,12 @@ protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder {
508544
protocol SpanBoundsThunkBuilder: BoundsThunkBuilder {
509545
var typeMappings: [String: String] { get }
510546
var node: SyntaxProtocol { get }
547+
var isParameter: Bool { get }
511548
}
512549
extension SpanBoundsThunkBuilder {
513550
var desugaredType: TypeSyntax {
514551
get throws {
515-
let typeName = try getUnattributedType(oldType).description
552+
let typeName = getUnattributedType(oldType).description
516553
guard let desugaredTypeName = typeMappings[typeName] else {
517554
throw DiagnosticError(
518555
"unable to desugar type with name '\(typeName)'", node: node)
@@ -547,14 +584,18 @@ extension SpanBoundsThunkBuilder {
547584
var newType: TypeSyntax {
548585
get throws {
549586
let (strippedArg, isConst) = dropCxxQualifiers(try genericArg)
550-
let mutablePrefix = if isConst || !enableMutableSpan {
587+
let mutablePrefix = if isConst {
551588
""
552589
} else {
553590
"Mutable"
554591
}
555-
return replaceBaseType(
592+
let mainType = replaceBaseType(
556593
oldType,
557594
TypeSyntax("\(raw: mutablePrefix)Span<\(raw: strippedArg)>"))
595+
if !isConst && isParameter {
596+
return TypeSyntax("inout \(mainType)")
597+
}
598+
return mainType
558599
}
559600
}
560601
}
@@ -563,13 +604,14 @@ protocol PointerBoundsThunkBuilder: BoundsThunkBuilder {
563604
var nullable: Bool { get }
564605
var isSizedBy: Bool { get }
565606
var generateSpan: Bool { get }
607+
var isParameter: Bool { get }
566608
}
567609

568610
extension PointerBoundsThunkBuilder {
569611
var nullable: Bool { return oldType.is(OptionalTypeSyntax.self) }
570612

571613
var newType: TypeSyntax { get throws {
572-
return try transformType(oldType, generateSpan, isSizedBy) }
614+
return try transformType(oldType, generateSpan, isSizedBy, isParameter) }
573615
}
574616
}
575617

@@ -599,8 +641,9 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
599641
public let nonescaping: Bool
600642
public let isSizedBy: Bool
601643
public let dependencies: [LifetimeDependence]
644+
let isParameter: Bool = false
602645

603-
var generateSpan: Bool { !dependencies.isEmpty && (!isMutablePointerType(oldType) || enableMutableSpan)}
646+
var generateSpan: Bool { !dependencies.isEmpty }
604647

605648
var oldType: TypeSyntax {
606649
return signature.returnClause!.type
@@ -623,9 +666,25 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
623666
} else {
624667
"start"
625668
}
669+
var cast = try newType
670+
if nullable {
671+
if let optType = cast.as(OptionalTypeSyntax.self) {
672+
cast = optType.wrappedType
673+
}
674+
return """
675+
{ () in
676+
let _resultValue = \(call)
677+
if unsafe _resultValue == nil {
678+
return nil
679+
} else {
680+
return unsafe \(raw: try cast)(\(raw: startLabel): _resultValue!, count: Int(\(countExpr)))
681+
}
682+
}()
683+
"""
684+
}
626685
return
627686
"""
628-
\(raw: try newType)(\(raw: startLabel): \(call), count: Int(\(countExpr)))
687+
unsafe \(raw: try cast)(\(raw: startLabel): \(call), count: Int(\(countExpr)))
629688
"""
630689
}
631690
}
@@ -639,8 +698,9 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
639698
public let nonescaping: Bool
640699
public let isSizedBy: Bool
641700
public let skipTrivialCount: Bool
701+
let isParameter: Bool = true
642702

643-
var generateSpan: Bool { nonescaping && (!isMutablePointerType(oldType) || enableMutableSpan) }
703+
var generateSpan: Bool { nonescaping }
644704

645705
func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
646706
-> (FunctionSignatureSyntax, Bool) {
@@ -702,11 +762,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
702762
let call = try base.buildFunctionCall(args)
703763
let ptrRef = unwrapIfNullable(ExprSyntax(DeclReferenceExprSyntax(baseName: name)))
704764

705-
let funcName = isSizedBy ? "withUnsafeBytes" : "withUnsafeBufferPointer"
765+
let funcName = switch (isSizedBy, isMutablePointerType(oldType)) {
766+
case (true, true): "withUnsafeMutableBytes"
767+
case (true, false): "withUnsafeBytes"
768+
case (false, true): "withUnsafeMutableBufferPointer"
769+
case (false, false): "withUnsafeBufferPointer"
770+
}
706771
let unwrappedCall = ExprSyntax(
707772
"""
708-
\(ptrRef).\(raw: funcName) { \(unwrappedName) in
709-
return unsafe \(call)
773+
unsafe \(ptrRef).\(raw: funcName) { \(unwrappedName) in
774+
return \(call)
710775
}
711776
""")
712777
return unwrappedCall
@@ -766,11 +831,11 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
766831
nullArgs[index] = ExprSyntax(NilLiteralExprSyntax(nilKeyword: .keyword(.nil)))
767832
return ExprSyntax(
768833
"""
769-
if \(name) == nil {
770-
unsafe \(try base.buildFunctionCall(nullArgs))
771-
} else {
772-
\(unwrappedCall)
773-
}
834+
{ () in return if \(name) == nil {
835+
\(try base.buildFunctionCall(nullArgs))
836+
} else {
837+
\(unwrappedCall)
838+
} }()
774839
""")
775840
}
776841
return unwrappedCall
@@ -1161,7 +1226,7 @@ public struct SwiftifyImportMacro: PeerMacro {
11611226
}
11621227
}
11631228

1164-
static func lifetimeAttributes(_ funcDecl: FunctionDeclSyntax,
1229+
static func getReturnLifetimeAttribute(_ funcDecl: FunctionDeclSyntax,
11651230
_ dependencies: [SwiftifyExpr: [LifetimeDependence]]) -> [AttributeListSyntax.Element] {
11661231
let returnDependencies = dependencies[.`return`, default: []]
11671232
if returnDependencies.isEmpty {
@@ -1190,6 +1255,66 @@ public struct SwiftifyImportMacro: PeerMacro {
11901255
rightParen: .rightParenToken()))]
11911256
}
11921257

1258+
static func isMutableSpan(_ type: TypeSyntax) -> Bool {
1259+
if let optType = type.as(OptionalTypeSyntax.self) {
1260+
return isMutableSpan(optType.wrappedType)
1261+
}
1262+
if let impOptType = type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
1263+
return isMutableSpan(impOptType.wrappedType)
1264+
}
1265+
if let attrType = type.as(AttributedTypeSyntax.self) {
1266+
return isMutableSpan(attrType.baseType)
1267+
}
1268+
guard let identifierType = type.as(IdentifierTypeSyntax.self) else {
1269+
return false
1270+
}
1271+
let name = identifierType.name.text
1272+
return name == "MutableSpan" || name == "MutableRawSpan"
1273+
}
1274+
1275+
static func containsLifetimeAttr(_ attrs: AttributeListSyntax, for paramName: TokenSyntax) -> Bool {
1276+
for elem in attrs {
1277+
guard let attr = elem.as(AttributeSyntax.self) else {
1278+
continue
1279+
}
1280+
if attr.attributeName != "lifetime" {
1281+
continue
1282+
}
1283+
guard let args = attr.arguments?.as(LabeledExprListSyntax.self) else {
1284+
continue
1285+
}
1286+
for arg in args {
1287+
if arg.label == paramName {
1288+
return true
1289+
}
1290+
}
1291+
}
1292+
return false
1293+
}
1294+
1295+
// Mutable[Raw]Span parameters need explicit @lifetime annotations since they are inout
1296+
static func paramLifetimeAttributes(_ newSignature: FunctionSignatureSyntax, _ oldAttrs: AttributeListSyntax) -> [AttributeListSyntax.Element] {
1297+
var defaultLifetimes: [AttributeListSyntax.Element] = []
1298+
for param in newSignature.parameterClause.parameters {
1299+
if !isMutableSpan(param.type) {
1300+
continue
1301+
}
1302+
let paramName = param.secondName ?? param.firstName
1303+
if containsLifetimeAttr(oldAttrs, for: paramName) {
1304+
continue
1305+
}
1306+
let expr = ExprSyntax("\(paramName): copy \(paramName)")
1307+
1308+
defaultLifetimes.append(.attribute(AttributeSyntax(
1309+
atSign: .atSignToken(),
1310+
attributeName: IdentifierTypeSyntax(name: "lifetime"),
1311+
leftParen: .leftParenToken(),
1312+
arguments: .argumentList(LabeledExprListSyntax([LabeledExprSyntax(expression: expr)])),
1313+
rightParen: .rightParenToken())))
1314+
}
1315+
return defaultLifetimes
1316+
}
1317+
11931318
public static func expansion(
11941319
of node: AttributeSyntax,
11951320
providingPeersOf declaration: some DeclSyntaxProtocol,
@@ -1255,9 +1380,10 @@ public struct SwiftifyImportMacro: PeerMacro {
12551380
item: CodeBlockItemSyntax.Item(
12561381
ReturnStmtSyntax(
12571382
returnKeyword: .keyword(.return, trailingTrivia: " "),
1258-
expression: ExprSyntax("unsafe \(try builder.buildFunctionCall([:]))"))))
1383+
expression: try builder.buildFunctionCall([:]))))
12591384
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
1260-
let lifetimeAttrs = lifetimeAttributes(funcDecl, lifetimeDependencies)
1385+
let returnLifetimeAttribute = getReturnLifetimeAttribute(funcDecl, lifetimeDependencies)
1386+
let lifetimeAttrs = returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcDecl.attributes)
12611387
let disfavoredOverload : [AttributeListSyntax.Element] = (onlyReturnTypeChanged ? [
12621388
.attribute(
12631389
AttributeSyntax(

0 commit comments

Comments
 (0)