Skip to content

[Swiftify] enable mutable span #80387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 169 additions & 43 deletions lib/Macros/Sources/SwiftMacros/SwiftifyImportMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@ import SwiftSyntax
import SwiftSyntaxBuilder
import SwiftSyntaxMacros

// Disable emitting 'MutableSpan' until it has landed
let enableMutableSpan = false

// avoids depending on SwiftifyImport.swift
// all instances are reparsed and reinstantiated by the macro anyways,
// so linking is irrelevant
Expand Down Expand Up @@ -279,36 +276,49 @@ func getUnqualifiedStdName(_ type: String) -> String? {
func getSafePointerName(mut: Mutability, generateSpan: Bool, isRaw: Bool) -> TokenSyntax {
switch (mut, generateSpan, isRaw) {
case (.Immutable, true, true): return "RawSpan"
case (.Mutable, true, true): return if enableMutableSpan {
"MutableRawSpan"
} else {
"RawSpan"
}
case (.Mutable, true, true): return "MutableRawSpan"
case (.Immutable, false, true): return "UnsafeRawBufferPointer"
case (.Mutable, false, true): return "UnsafeMutableRawBufferPointer"

case (.Immutable, true, false): return "Span"
case (.Mutable, true, false): return if enableMutableSpan {
"MutableSpan"
} else {
"Span"
}
case (.Mutable, true, false): return "MutableSpan"
case (.Immutable, false, false): return "UnsafeBufferPointer"
case (.Mutable, false, false): return "UnsafeMutableBufferPointer"
}
}

func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool) throws -> TypeSyntax {
func hasOwnershipSpecifier(_ attrType: AttributedTypeSyntax) -> Bool {
return attrType.specifiers.contains(where: { e in
guard let simpleSpec = e.as(SimpleTypeSpecifierSyntax.self) else {
return false
}
let specifierText = simpleSpec.specifier.text
switch specifierText {
case "borrowing":
return true
case "inout":
return true
case "consuming":
return true
default:
return false
}
})
}

func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool, _ setMutableSpanInout: Bool) throws -> TypeSyntax {
if let optType = prev.as(OptionalTypeSyntax.self) {
return TypeSyntax(
optType.with(\.wrappedType, try transformType(optType.wrappedType, generateSpan, isSizedBy)))
optType.with(\.wrappedType, try transformType(optType.wrappedType, generateSpan, isSizedBy, setMutableSpanInout)))
}
if let impOptType = prev.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
return try transformType(impOptType.wrappedType, generateSpan, isSizedBy)
return try transformType(impOptType.wrappedType, generateSpan, isSizedBy, setMutableSpanInout)
}
if let attrType = prev.as(AttributedTypeSyntax.self) {
// We insert 'inout' by default for MutableSpan, but it shouldn't override existing ownership
let setMutableSpanInoutNext = setMutableSpanInout && !hasOwnershipSpecifier(attrType)
return TypeSyntax(
attrType.with(\.baseType, try transformType(attrType.baseType, generateSpan, isSizedBy)))
attrType.with(\.baseType, try transformType(attrType.baseType, generateSpan, isSizedBy, setMutableSpanInoutNext)))
}
let name = try getTypeName(prev)
let text = name.text
Expand All @@ -326,10 +336,15 @@ func transformType(_ prev: TypeSyntax, _ generateSpan: Bool, _ isSizedBy: Bool)
+ " - first type token is '\(text)'", node: name)
}
let token = getSafePointerName(mut: kind, generateSpan: generateSpan, isRaw: isSizedBy)
if isSizedBy {
return TypeSyntax(IdentifierTypeSyntax(name: token))
let mainType = if isSizedBy {
TypeSyntax(IdentifierTypeSyntax(name: token))
} else {
try replaceTypeName(prev, token)
}
return try replaceTypeName(prev, token)
if setMutableSpanInout && generateSpan && kind == .Mutable {
return TypeSyntax("inout \(mainType)")
}
return mainType
}

func isMutablePointerType(_ type: TypeSyntax) -> Bool {
Expand Down Expand Up @@ -431,10 +446,11 @@ struct FunctionCallBuilder: BoundsCheckedThunkBuilder {
let colon: TokenSyntax? = label != nil ? .colonToken() : nil
return LabeledExprSyntax(label: label, colon: colon, expression: arg, trailingComma: comma)
}
return ExprSyntax(
let call = ExprSyntax(
FunctionCallExprSyntax(
calledExpression: functionRef, leftParen: .leftParenToken(),
arguments: LabeledExprListSyntax(labeledArgs), rightParen: .rightParenToken()))
return "unsafe \(call)"
}
}

Expand All @@ -446,6 +462,7 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
public let node: SyntaxProtocol
public let nonescaping: Bool
let isSizedBy: Bool = false
let isParameter: Bool = true

func buildBoundsChecks() throws -> [CodeBlockItemSyntax.Item] {
return try base.buildBoundsChecks()
Expand All @@ -462,8 +479,26 @@ struct CxxSpanThunkBuilder: SpanBoundsThunkBuilder, ParamBoundsThunkBuilder {
var args = pointerArgs
let typeName = getUnattributedType(oldType).description
assert(args[index] == nil)
args[index] = ExprSyntax("\(raw: typeName)(\(raw: name))")
return try base.buildFunctionCall(args)

let (_, isConst) = dropCxxQualifiers(try genericArg)
if isConst {
args[index] = ExprSyntax("\(raw: typeName)(\(raw: name))")
return try base.buildFunctionCall(args)
} else {
let unwrappedName = TokenSyntax("_\(name)Ptr")
args[index] = ExprSyntax("\(raw: typeName)(\(unwrappedName))")
let call = try base.buildFunctionCall(args)

// MutableSpan - unlike Span - cannot be bitcast to std::span due to being ~Copyable,
// so unwrap it to an UnsafeMutableBufferPointer that we can cast
let unwrappedCall = ExprSyntax(
"""
unsafe \(name).withUnsafeMutableBufferPointer { \(unwrappedName) in
return \(call)
}
""")
return unwrappedCall
}
}
}

Expand All @@ -472,6 +507,7 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
public let signature: FunctionSignatureSyntax
public let typeMappings: [String: String]
public let node: SyntaxProtocol
let isParameter: Bool = false

var oldType: TypeSyntax {
return signature.returnClause!.type
Expand All @@ -490,12 +526,12 @@ struct CxxSpanReturnThunkBuilder: SpanBoundsThunkBuilder {
func buildFunctionCall(_ pointerArgs: [Int: ExprSyntax]) throws -> ExprSyntax {
let call = try base.buildFunctionCall(pointerArgs)
let (_, isConst) = dropCxxQualifiers(try genericArg)
let cast = if isConst || !enableMutableSpan {
let cast = if isConst {
"Span"
} else {
"MutableSpan"
}
return "_cxxOverrideLifetime(\(raw: cast)(_unsafeCxxSpan: \(call)), copying: ())"
return "unsafe _cxxOverrideLifetime(\(raw: cast)(_unsafeCxxSpan: \(call)), copying: ())"
}
}

Expand All @@ -508,11 +544,12 @@ protocol BoundsThunkBuilder: BoundsCheckedThunkBuilder {
protocol SpanBoundsThunkBuilder: BoundsThunkBuilder {
var typeMappings: [String: String] { get }
var node: SyntaxProtocol { get }
var isParameter: Bool { get }
}
extension SpanBoundsThunkBuilder {
var desugaredType: TypeSyntax {
get throws {
let typeName = try getUnattributedType(oldType).description
let typeName = getUnattributedType(oldType).description
guard let desugaredTypeName = typeMappings[typeName] else {
throw DiagnosticError(
"unable to desugar type with name '\(typeName)'", node: node)
Expand Down Expand Up @@ -547,14 +584,18 @@ extension SpanBoundsThunkBuilder {
var newType: TypeSyntax {
get throws {
let (strippedArg, isConst) = dropCxxQualifiers(try genericArg)
let mutablePrefix = if isConst || !enableMutableSpan {
let mutablePrefix = if isConst {
""
} else {
"Mutable"
}
return replaceBaseType(
let mainType = replaceBaseType(
oldType,
TypeSyntax("\(raw: mutablePrefix)Span<\(raw: strippedArg)>"))
if !isConst && isParameter {
return TypeSyntax("inout \(mainType)")
}
return mainType
}
}
}
Expand All @@ -563,13 +604,14 @@ protocol PointerBoundsThunkBuilder: BoundsThunkBuilder {
var nullable: Bool { get }
var isSizedBy: Bool { get }
var generateSpan: Bool { get }
var isParameter: Bool { get }
}

extension PointerBoundsThunkBuilder {
var nullable: Bool { return oldType.is(OptionalTypeSyntax.self) }

var newType: TypeSyntax { get throws {
return try transformType(oldType, generateSpan, isSizedBy) }
return try transformType(oldType, generateSpan, isSizedBy, isParameter) }
}
}

Expand Down Expand Up @@ -599,8 +641,9 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
public let nonescaping: Bool
public let isSizedBy: Bool
public let dependencies: [LifetimeDependence]
let isParameter: Bool = false

var generateSpan: Bool { !dependencies.isEmpty && (!isMutablePointerType(oldType) || enableMutableSpan)}
var generateSpan: Bool { !dependencies.isEmpty }

var oldType: TypeSyntax {
return signature.returnClause!.type
Expand All @@ -623,9 +666,25 @@ struct CountedOrSizedReturnPointerThunkBuilder: PointerBoundsThunkBuilder {
} else {
"start"
}
var cast = try newType
if nullable {
if let optType = cast.as(OptionalTypeSyntax.self) {
cast = optType.wrappedType
}
return """
{ () in
let _resultValue = \(call)
if unsafe _resultValue == nil {
return nil
} else {
return unsafe \(raw: try cast)(\(raw: startLabel): _resultValue!, count: Int(\(countExpr)))
}
}()
"""
}
return
"""
\(raw: try newType)(\(raw: startLabel): \(call), count: Int(\(countExpr)))
unsafe \(raw: try cast)(\(raw: startLabel): \(call), count: Int(\(countExpr)))
"""
}
}
Expand All @@ -639,8 +698,9 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
public let nonescaping: Bool
public let isSizedBy: Bool
public let skipTrivialCount: Bool
let isParameter: Bool = true

var generateSpan: Bool { nonescaping && (!isMutablePointerType(oldType) || enableMutableSpan) }
var generateSpan: Bool { nonescaping }

func buildFunctionSignature(_ argTypes: [Int: TypeSyntax?], _ returnType: TypeSyntax?) throws
-> (FunctionSignatureSyntax, Bool) {
Expand Down Expand Up @@ -702,11 +762,16 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
let call = try base.buildFunctionCall(args)
let ptrRef = unwrapIfNullable(ExprSyntax(DeclReferenceExprSyntax(baseName: name)))

let funcName = isSizedBy ? "withUnsafeBytes" : "withUnsafeBufferPointer"
let funcName = switch (isSizedBy, isMutablePointerType(oldType)) {
case (true, true): "withUnsafeMutableBytes"
case (true, false): "withUnsafeBytes"
case (false, true): "withUnsafeMutableBufferPointer"
case (false, false): "withUnsafeBufferPointer"
}
let unwrappedCall = ExprSyntax(
"""
\(ptrRef).\(raw: funcName) { \(unwrappedName) in
return unsafe \(call)
unsafe \(ptrRef).\(raw: funcName) { \(unwrappedName) in
return \(call)
}
""")
return unwrappedCall
Expand Down Expand Up @@ -766,11 +831,11 @@ struct CountedOrSizedPointerThunkBuilder: ParamBoundsThunkBuilder, PointerBounds
nullArgs[index] = ExprSyntax(NilLiteralExprSyntax(nilKeyword: .keyword(.nil)))
return ExprSyntax(
"""
if \(name) == nil {
unsafe \(try base.buildFunctionCall(nullArgs))
} else {
\(unwrappedCall)
}
{ () in return if \(name) == nil {
\(try base.buildFunctionCall(nullArgs))
} else {
\(unwrappedCall)
} }()
""")
}
return unwrappedCall
Expand Down Expand Up @@ -1161,7 +1226,7 @@ public struct SwiftifyImportMacro: PeerMacro {
}
}

static func lifetimeAttributes(_ funcDecl: FunctionDeclSyntax,
static func getReturnLifetimeAttribute(_ funcDecl: FunctionDeclSyntax,
_ dependencies: [SwiftifyExpr: [LifetimeDependence]]) -> [AttributeListSyntax.Element] {
let returnDependencies = dependencies[.`return`, default: []]
if returnDependencies.isEmpty {
Expand Down Expand Up @@ -1190,6 +1255,66 @@ public struct SwiftifyImportMacro: PeerMacro {
rightParen: .rightParenToken()))]
}

static func isMutableSpan(_ type: TypeSyntax) -> Bool {
if let optType = type.as(OptionalTypeSyntax.self) {
return isMutableSpan(optType.wrappedType)
}
if let impOptType = type.as(ImplicitlyUnwrappedOptionalTypeSyntax.self) {
return isMutableSpan(impOptType.wrappedType)
}
if let attrType = type.as(AttributedTypeSyntax.self) {
return isMutableSpan(attrType.baseType)
}
guard let identifierType = type.as(IdentifierTypeSyntax.self) else {
return false
}
let name = identifierType.name.text
return name == "MutableSpan" || name == "MutableRawSpan"
}

static func containsLifetimeAttr(_ attrs: AttributeListSyntax, for paramName: TokenSyntax) -> Bool {
for elem in attrs {
guard let attr = elem.as(AttributeSyntax.self) else {
continue
}
if attr.attributeName != "lifetime" {
continue
}
guard let args = attr.arguments?.as(LabeledExprListSyntax.self) else {
continue
}
for arg in args {
if arg.label == paramName {
return true
}
}
}
return false
}

// Mutable[Raw]Span parameters need explicit @lifetime annotations since they are inout
static func paramLifetimeAttributes(_ newSignature: FunctionSignatureSyntax, _ oldAttrs: AttributeListSyntax) -> [AttributeListSyntax.Element] {
var defaultLifetimes: [AttributeListSyntax.Element] = []
for param in newSignature.parameterClause.parameters {
if !isMutableSpan(param.type) {
continue
}
let paramName = param.secondName ?? param.firstName
if containsLifetimeAttr(oldAttrs, for: paramName) {
continue
}
let expr = ExprSyntax("\(paramName): copy \(paramName)")

defaultLifetimes.append(.attribute(AttributeSyntax(
atSign: .atSignToken(),
attributeName: IdentifierTypeSyntax(name: "lifetime"),
leftParen: .leftParenToken(),
arguments: .argumentList(LabeledExprListSyntax([LabeledExprSyntax(expression: expr)])),
rightParen: .rightParenToken())))
}
return defaultLifetimes
}

public static func expansion(
of node: AttributeSyntax,
providingPeersOf declaration: some DeclSyntaxProtocol,
Expand Down Expand Up @@ -1255,9 +1380,10 @@ public struct SwiftifyImportMacro: PeerMacro {
item: CodeBlockItemSyntax.Item(
ReturnStmtSyntax(
returnKeyword: .keyword(.return, trailingTrivia: " "),
expression: ExprSyntax("unsafe \(try builder.buildFunctionCall([:]))"))))
expression: try builder.buildFunctionCall([:]))))
let body = CodeBlockSyntax(statements: CodeBlockItemListSyntax(checks + [call]))
let lifetimeAttrs = lifetimeAttributes(funcDecl, lifetimeDependencies)
let returnLifetimeAttribute = getReturnLifetimeAttribute(funcDecl, lifetimeDependencies)
let lifetimeAttrs = returnLifetimeAttribute + paramLifetimeAttributes(newSignature, funcDecl.attributes)
let disfavoredOverload : [AttributeListSyntax.Element] = (onlyReturnTypeChanged ? [
.attribute(
AttributeSyntax(
Expand Down
Loading