Skip to content

Expand property accesses in #expect() and #require(). #161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions Sources/Testing/Expectations/ExpectationChecking+Macro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,67 @@ public func __checkInoutFunctionCall<T, /*each*/ U, R>(
)
}

// MARK: - Property access

/// Check that an expectation has passed after a condition has been evaluated
/// and throw an error if it failed.
///
/// This overload is used by property accesses:
///
/// ```swift
/// #expect(x.isFoodTruck)
/// ```
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public func __checkPropertyAccess<T>(
_ lhs: T, getting memberAccess: (T) -> Bool,
sourceCode: SourceCode,
comments: @autoclosure () -> [Comment],
isRequired: Bool,
sourceLocation: SourceLocation
) -> Result<Void, any Error> {
let condition = memberAccess(lhs)
return __checkValue(
condition,
sourceCode: sourceCode,
expandedExpressionDescription: sourceCode.expandWithOperands(lhs, condition),
comments: comments(),
isRequired: isRequired,
sourceLocation: sourceLocation
)
}

/// Check that an expectation has passed after a condition has been evaluated
/// and throw an error if it failed.
///
/// This overload is used to conditionally unwrap optional values produced from
/// expanded property accesses:
///
/// ```swift
/// let z = try #require(x.nearestFoodTruck)
/// ```
///
/// - Warning: This function is used to implement the `#expect()` and
/// `#require()` macros. Do not call it directly.
public func __checkPropertyAccess<T, U>(
_ lhs: T, getting memberAccess: (T) -> U?,
sourceCode: SourceCode,
comments: @autoclosure () -> [Comment],
isRequired: Bool,
sourceLocation: SourceLocation
) -> Result<U, any Error> {
let optionalValue = memberAccess(lhs)
return __checkValue(
optionalValue,
sourceCode: sourceCode,
expandedExpressionDescription: sourceCode.expandWithOperands(lhs, optionalValue),
comments: comments(),
isRequired: isRequired,
sourceLocation: sourceLocation
)
}

// MARK: - Collection diffing

/// Check that an expectation has passed after a condition has been evaluated
Expand Down
12 changes: 12 additions & 0 deletions Sources/Testing/SourceAttribution/SourceCode+Macro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,16 @@ extension SourceCode {
public static func __functionCall(_ value: String?, _ functionName: String, _ arguments: (label: String?, value: String)...) -> Self {
Self(kind: .functionCall(value: value, functionName: functionName, arguments: arguments))
}

/// Create an instance of ``SourceCode`` representing a property access.
///
/// - Parameters:
/// - value: The value whose property was accessed.
/// - keyPath: The key path, relative to `value`, that was accessed, not
/// including a leading backslash or period.
///
/// - Returns: A new instance of ``SourceCode``.
public static func __fromPropertyAccess(_ value: String, _ keyPath: String) -> Self {
Self(kind: .propertyAccess(value: value, keyPath: keyPath))
}
}
13 changes: 13 additions & 0 deletions Sources/Testing/SourceAttribution/SourceCode.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ public struct SourceCode: Sendable {
/// - functionName: The name of the function that was called.
/// - arguments: The arguments passed to the function.
case functionCall(value: String?, functionName: String, arguments: [(label: String?, value: String)])

/// The source code represets a property access.
///
/// - Parameters:
/// - value: The value whose property was accessed.
/// - keyPath: The key path, relative to `value`, that was accessed, not
/// including a leading backslash or period.
case propertyAccess(value: String, keyPath: String)
}

/// The kind of syntax node represented by this instance.
Expand Down Expand Up @@ -105,6 +113,9 @@ public struct SourceCode: Sendable {
return "\(sourceCodeAndValue(value, lhs)).\(functionName)(\(argumentList))"
}
return "\(functionName)(\(argumentList))"
case let .propertyAccess(value, keyPath):
let rhs = additionalValuesArray.first
return "\(sourceCodeAndValue(value, lhs)).\(sourceCodeAndValue(keyPath, rhs ?? nil, includeParenthesesIfNeeded: false))"
}
}
}
Expand Down Expand Up @@ -143,6 +154,8 @@ extension SourceCode: CustomStringConvertible, CustomDebugStringConvertible {
return "\(value).\(functionName)(\(argumentList))"
}
return "\(functionName)(\(argumentList))"
case let .propertyAccess(value, keyPath):
return "\(value).\(keyPath)"
}
}

Expand Down
101 changes: 96 additions & 5 deletions Sources/TestingMacros/Support/ConditionArgumentParsing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,65 @@ private func _parseCondition(from expr: ClosureExprSyntax, for macro: some Frees
return Condition(expression: expr)
}

/// Extract the underlying expression from an optional-chained expression as
/// well as the number of question marks required to reach it.
///
/// - Parameters:
/// - expr: The expression to examine, typically the `base` expression of a
/// `MemberAccessExprSyntax` instance.
///
/// - Returns: A copy of `expr` with trailing question marks from optional
/// chaining removed, as well as a string containing the number of question
/// marks needed to access a member of `expr` after it has been assigned to
/// another variable. If `expr` does not contain any optional chaining, it is
/// returned verbatim along with the empty string.
///
/// This function is used when expanding member accesses (either functions or
/// properties) that could contain optional chaining expressions such as
/// `foo?.bar()`. Since, in this case, `bar()` is ultimately going to be called
/// on a closure argument (i.e. `$0`), it is necessary to determine the number
/// of question mark characters needed to correctly construct that expression
/// and to capture the underlying expression of `foo?` without question marks so
/// that it remains syntactically correct when used without `bar()`.
private func _exprFromOptionalChainedExpr(_ expr: some ExprSyntaxProtocol) -> (ExprSyntax, questionMarks: String) {
let originalExpr = expr
var expr = ExprSyntax(expr)
var questionMarkCount = 0

while let optionalExpr = expr.as(OptionalChainingExprSyntax.self) {
// If the rightmost base expression is an optional-chained member access
// expression (e.g. "bar?" in the member access expression
// "foo.bar?.isQuux"), drop the question mark.
expr = optionalExpr.expression
questionMarkCount += 1
}

// If the rightmost expression Check if any of the member accesses in the expression use optional
// chaining and, if one does, ensure we preserve optional chaining in the
// macro expansion.
if questionMarkCount == 0 {
func isOptionalChained(_ expr: some ExprSyntaxProtocol) -> Bool {
if expr.is(OptionalChainingExprSyntax.self) {
return true
} else if let memberAccessBaseExpr = expr.as(MemberAccessExprSyntax.self)?.base {
return isOptionalChained(memberAccessBaseExpr)
}
return false
}
if isOptionalChained(originalExpr) {
questionMarkCount = 1
}
}

let questionMarks = String(repeating: "?", count: questionMarkCount)

return (expr, questionMarks)
}

/// Parse a condition argument from a member function call.
///
/// - Parameters:
/// - expr: The function call expression.
/// - memberAccessExpr: The called expression of `expr`.
/// - macro: The macro expression being expanded.
/// - context: The macro context in which the expression is being parsed.
///
Expand Down Expand Up @@ -310,15 +364,20 @@ private func _parseCondition(from expr: FunctionCallExprSyntax, for macro: some
.map(\.expression)
.map { Argument(expression: $0) }

var baseExprForSourceCode: ExprSyntax?
var conditionArguments = [Argument]()
if let memberAccessExpr, let baseExpr = memberAccessExpr.base {
if let memberAccessExpr, var baseExpr = memberAccessExpr.base {
let questionMarks: String
(baseExpr, questionMarks) = _exprFromOptionalChainedExpr(baseExpr)
baseExprForSourceCode = baseExpr

conditionArguments.append(Argument(expression: "\(baseExpr.trimmed).self")) // BUG: rdar://113152370
conditionArguments.append(
Argument(
label: "calling",
expression: """
{
$0.\(functionName.trimmed)(\(LabeledExprListSyntax(indexedArguments)))
$0\(raw: questionMarks).\(functionName.trimmed)(\(LabeledExprListSyntax(indexedArguments)))
}
"""
)
Expand All @@ -345,7 +404,37 @@ private func _parseCondition(from expr: FunctionCallExprSyntax, for macro: some
return Condition(
expandedFunctionName,
arguments: conditionArguments,
sourceCode: createSourceCodeExprForFunctionCall(memberAccessExpr?.base, functionName, argumentList)
sourceCode: createSourceCodeExprForFunctionCall(baseExprForSourceCode, functionName, argumentList)
)
}

/// Parse a condition argument from a property access.
///
/// - Parameters:
/// - expr: The member access expression.
/// - macro: The macro expression being expanded.
/// - context: The macro context in which the expression is being parsed.
///
/// - Returns: An instance of ``Condition`` describing `expr`.
private func _parseCondition(from expr: MemberAccessExprSyntax, for macro: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext) -> Condition {
// Only handle member access expressions where the base expression is known
// and where there are no argument names (which would otherwise indicate a
// reference to a member function which wouldn't resolve to anything useful at
// runtime.)
guard var baseExpr = expr.base, expr.declName.argumentNames == nil else {
return Condition(expression: expr)
}

let questionMarks: String
(baseExpr, questionMarks) = _exprFromOptionalChainedExpr(baseExpr)

return Condition(
"__checkPropertyAccess",
arguments: [
Argument(expression: "\(baseExpr.trimmed).self"),
Argument(label: "getting", expression: "{ $0\(raw: questionMarks).\(expr.declName.baseName) }")
],
sourceCode: createSourceCodeExprForPropertyAccess(baseExpr, expr.declName)
)
}

Expand Down Expand Up @@ -375,9 +464,11 @@ private func _parseCondition(from expr: ExprSyntax, for macro: some Freestanding
return _parseCondition(from: closureExpr, for: macro, in: context)
}

// Handle function calls.
// Handle function calls and member accesses.
if let functionCallExpr = expr.as(FunctionCallExprSyntax.self) {
return _parseCondition(from: functionCallExpr, for: macro, in: context)
} else if let memberAccessExpr = expr.as(MemberAccessExprSyntax.self) {
return _parseCondition(from: memberAccessExpr, for: macro, in: context)
}

// Parentheses are parsed as if they were tuples, so (true && false) appears
Expand Down
9 changes: 9 additions & 0 deletions Sources/TestingMacros/Support/SourceCodeCapturing.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,12 @@ func createSourceCodeExprForFunctionCall(_ value: (some SyntaxProtocol)?, _ func

return ".__functionCall(\(arguments))"
}

func createSourceCodeExprForPropertyAccess(_ value: ExprSyntax, _ keyPath: DeclReferenceExprSyntax) -> ExprSyntax {
let arguments = LabeledExprListSyntax {
LabeledExprSyntax(expression: StringLiteralExprSyntax(content: value.trimmedDescription))
LabeledExprSyntax(expression: StringLiteralExprSyntax(content: keyPath.baseName.trimmedDescription))
}

return ".__fromPropertyAccess(\(arguments))"
}
26 changes: 25 additions & 1 deletion Tests/TestingMacrosTests/ConditionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ struct ConditionMacroTests {
##"Testing.__checkValue(a.b(&c, d), sourceCode: .__fromSyntaxNode("a.b(&c, d)"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a.b(try c()))"##:
##"Testing.__checkValue(a.b(try c()), sourceCode: .__fromSyntaxNode("a.b(try c())"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a?.b(c))"##:
##"Testing.__checkFunctionCall(a.self, calling: { $0?.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a???.b(c))"##:
##"Testing.__checkFunctionCall(a.self, calling: { $0???.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a?.b.c(d))"##:
##"Testing.__checkFunctionCall(a?.b.self, calling: { $0?.c($1) }, d, sourceCode: .__functionCall("a?.b", "c", (nil, "d")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect({}())"##:
##"Testing.__checkValue({}(), sourceCode: .__fromSyntaxNode("{}()"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a.b(c: d))"##:
Expand All @@ -74,6 +80,12 @@ struct ConditionMacroTests {
##"Testing.__checkValue(a.b { c }, sourceCode: .__fromSyntaxNode("a.b { c }"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a, sourceLocation: someValue)"##:
##"Testing.__checkValue(a, sourceCode: .__fromSyntaxNode("a"), comments: [], isRequired: false, sourceLocation: someValue).__expected()"##,
##"#expect(a.isB)"##:
##"Testing.__checkPropertyAccess(a.self, getting: { $0.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a???.isB)"##:
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
##"#expect(a?.b.isB)"##:
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, sourceCode: .__fromPropertyAccess("a?.b", "isB"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
]
)
func expectMacro(input: String, expectedOutput: String) throws {
Expand Down Expand Up @@ -128,6 +140,12 @@ struct ConditionMacroTests {
##"Testing.__checkValue(a.b(&c, d), sourceCode: .__fromSyntaxNode("a.b(&c, d)"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a.b(try c()))"##:
##"Testing.__checkValue(a.b(try c()), sourceCode: .__fromSyntaxNode("a.b(try c())"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a?.b(c))"##:
##"Testing.__checkFunctionCall(a.self, calling: { $0?.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a???.b(c))"##:
##"Testing.__checkFunctionCall(a.self, calling: { $0???.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a?.b.c(d))"##:
##"Testing.__checkFunctionCall(a?.b.self, calling: { $0?.c($1) }, d, sourceCode: .__functionCall("a?.b", "c", (nil, "d")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require({}())"##:
##"Testing.__checkValue({}(), sourceCode: .__fromSyntaxNode("{}()"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a.b(c: d))"##:
Expand All @@ -136,6 +154,12 @@ struct ConditionMacroTests {
##"Testing.__checkValue(a.b { c }, sourceCode: .__fromSyntaxNode("a.b { c }"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a, sourceLocation: someValue)"##:
##"Testing.__checkValue(a, sourceCode: .__fromSyntaxNode("a"), comments: [], isRequired: true, sourceLocation: someValue).__required()"##,
##"#require(a.isB)"##:
##"Testing.__checkPropertyAccess(a.self, getting: { $0.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a???.isB)"##:
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(a?.b.isB)"##:
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, sourceCode: .__fromPropertyAccess("a?.b", "isB"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
]
)
func requireMacro(input: String, expectedOutput: String) throws {
Expand All @@ -147,7 +171,7 @@ struct ConditionMacroTests {
@Test("Unwrapping #require() macro",
arguments: [
##"#require(Optional<Int>.none)"##:
##"Testing.__checkValue(Optional<Int>.none, sourceCode: .__fromSyntaxNode("Optional<Int>.none"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"Testing.__checkPropertyAccess(Optional<Int>.self, getting: { $0.none }, sourceCode: .__fromPropertyAccess("Optional<Int>", "none"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(nil ?? 123)"##:
##"Testing.__checkBinaryOperation(nil, { $0 ?? $1() }, 123, sourceCode: .__fromBinaryOperation("nil", "??", "123"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
##"#require(123 ?? nil)"##:
Expand Down