Skip to content

Commit f955de6

Browse files
authored
Expand property accesses in #expect() and #require(). (#161)
1 parent 593327a commit f955de6

File tree

6 files changed

+216
-6
lines changed

6 files changed

+216
-6
lines changed

Sources/Testing/Expectations/ExpectationChecking+Macro.swift

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,67 @@ public func __checkInoutFunctionCall<T, /*each*/ U, R>(
277277
)
278278
}
279279

280+
// MARK: - Property access
281+
282+
/// Check that an expectation has passed after a condition has been evaluated
283+
/// and throw an error if it failed.
284+
///
285+
/// This overload is used by property accesses:
286+
///
287+
/// ```swift
288+
/// #expect(x.isFoodTruck)
289+
/// ```
290+
///
291+
/// - Warning: This function is used to implement the `#expect()` and
292+
/// `#require()` macros. Do not call it directly.
293+
public func __checkPropertyAccess<T>(
294+
_ lhs: T, getting memberAccess: (T) -> Bool,
295+
sourceCode: SourceCode,
296+
comments: @autoclosure () -> [Comment],
297+
isRequired: Bool,
298+
sourceLocation: SourceLocation
299+
) -> Result<Void, any Error> {
300+
let condition = memberAccess(lhs)
301+
return __checkValue(
302+
condition,
303+
sourceCode: sourceCode,
304+
expandedExpressionDescription: sourceCode.expandWithOperands(lhs, condition),
305+
comments: comments(),
306+
isRequired: isRequired,
307+
sourceLocation: sourceLocation
308+
)
309+
}
310+
311+
/// Check that an expectation has passed after a condition has been evaluated
312+
/// and throw an error if it failed.
313+
///
314+
/// This overload is used to conditionally unwrap optional values produced from
315+
/// expanded property accesses:
316+
///
317+
/// ```swift
318+
/// let z = try #require(x.nearestFoodTruck)
319+
/// ```
320+
///
321+
/// - Warning: This function is used to implement the `#expect()` and
322+
/// `#require()` macros. Do not call it directly.
323+
public func __checkPropertyAccess<T, U>(
324+
_ lhs: T, getting memberAccess: (T) -> U?,
325+
sourceCode: SourceCode,
326+
comments: @autoclosure () -> [Comment],
327+
isRequired: Bool,
328+
sourceLocation: SourceLocation
329+
) -> Result<U, any Error> {
330+
let optionalValue = memberAccess(lhs)
331+
return __checkValue(
332+
optionalValue,
333+
sourceCode: sourceCode,
334+
expandedExpressionDescription: sourceCode.expandWithOperands(lhs, optionalValue),
335+
comments: comments(),
336+
isRequired: isRequired,
337+
sourceLocation: sourceLocation
338+
)
339+
}
340+
280341
// MARK: - Collection diffing
281342

282343
/// Check that an expectation has passed after a condition has been evaluated

Sources/Testing/SourceAttribution/SourceCode+Macro.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,16 @@ extension SourceCode {
4545
public static func __functionCall(_ value: String?, _ functionName: String, _ arguments: (label: String?, value: String)...) -> Self {
4646
Self(kind: .functionCall(value: value, functionName: functionName, arguments: arguments))
4747
}
48+
49+
/// Create an instance of ``SourceCode`` representing a property access.
50+
///
51+
/// - Parameters:
52+
/// - value: The value whose property was accessed.
53+
/// - keyPath: The key path, relative to `value`, that was accessed, not
54+
/// including a leading backslash or period.
55+
///
56+
/// - Returns: A new instance of ``SourceCode``.
57+
public static func __fromPropertyAccess(_ value: String, _ keyPath: String) -> Self {
58+
Self(kind: .propertyAccess(value: value, keyPath: keyPath))
59+
}
4860
}

Sources/Testing/SourceAttribution/SourceCode.swift

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,14 @@ public struct SourceCode: Sendable {
3737
/// - functionName: The name of the function that was called.
3838
/// - arguments: The arguments passed to the function.
3939
case functionCall(value: String?, functionName: String, arguments: [(label: String?, value: String)])
40+
41+
/// The source code represets a property access.
42+
///
43+
/// - Parameters:
44+
/// - value: The value whose property was accessed.
45+
/// - keyPath: The key path, relative to `value`, that was accessed, not
46+
/// including a leading backslash or period.
47+
case propertyAccess(value: String, keyPath: String)
4048
}
4149

4250
/// The kind of syntax node represented by this instance.
@@ -105,6 +113,9 @@ public struct SourceCode: Sendable {
105113
return "\(sourceCodeAndValue(value, lhs)).\(functionName)(\(argumentList))"
106114
}
107115
return "\(functionName)(\(argumentList))"
116+
case let .propertyAccess(value, keyPath):
117+
let rhs = additionalValuesArray.first
118+
return "\(sourceCodeAndValue(value, lhs)).\(sourceCodeAndValue(keyPath, rhs ?? nil, includeParenthesesIfNeeded: false))"
108119
}
109120
}
110121
}
@@ -143,6 +154,8 @@ extension SourceCode: CustomStringConvertible, CustomDebugStringConvertible {
143154
return "\(value).\(functionName)(\(argumentList))"
144155
}
145156
return "\(functionName)(\(argumentList))"
157+
case let .propertyAccess(value, keyPath):
158+
return "\(value).\(keyPath)"
146159
}
147160
}
148161

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,11 +243,65 @@ private func _parseCondition(from expr: ClosureExprSyntax, for macro: some Frees
243243
return Condition(expression: expr)
244244
}
245245

246+
/// Extract the underlying expression from an optional-chained expression as
247+
/// well as the number of question marks required to reach it.
248+
///
249+
/// - Parameters:
250+
/// - expr: The expression to examine, typically the `base` expression of a
251+
/// `MemberAccessExprSyntax` instance.
252+
///
253+
/// - Returns: A copy of `expr` with trailing question marks from optional
254+
/// chaining removed, as well as a string containing the number of question
255+
/// marks needed to access a member of `expr` after it has been assigned to
256+
/// another variable. If `expr` does not contain any optional chaining, it is
257+
/// returned verbatim along with the empty string.
258+
///
259+
/// This function is used when expanding member accesses (either functions or
260+
/// properties) that could contain optional chaining expressions such as
261+
/// `foo?.bar()`. Since, in this case, `bar()` is ultimately going to be called
262+
/// on a closure argument (i.e. `$0`), it is necessary to determine the number
263+
/// of question mark characters needed to correctly construct that expression
264+
/// and to capture the underlying expression of `foo?` without question marks so
265+
/// that it remains syntactically correct when used without `bar()`.
266+
private func _exprFromOptionalChainedExpr(_ expr: some ExprSyntaxProtocol) -> (ExprSyntax, questionMarks: String) {
267+
let originalExpr = expr
268+
var expr = ExprSyntax(expr)
269+
var questionMarkCount = 0
270+
271+
while let optionalExpr = expr.as(OptionalChainingExprSyntax.self) {
272+
// If the rightmost base expression is an optional-chained member access
273+
// expression (e.g. "bar?" in the member access expression
274+
// "foo.bar?.isQuux"), drop the question mark.
275+
expr = optionalExpr.expression
276+
questionMarkCount += 1
277+
}
278+
279+
// If the rightmost expression Check if any of the member accesses in the expression use optional
280+
// chaining and, if one does, ensure we preserve optional chaining in the
281+
// macro expansion.
282+
if questionMarkCount == 0 {
283+
func isOptionalChained(_ expr: some ExprSyntaxProtocol) -> Bool {
284+
if expr.is(OptionalChainingExprSyntax.self) {
285+
return true
286+
} else if let memberAccessBaseExpr = expr.as(MemberAccessExprSyntax.self)?.base {
287+
return isOptionalChained(memberAccessBaseExpr)
288+
}
289+
return false
290+
}
291+
if isOptionalChained(originalExpr) {
292+
questionMarkCount = 1
293+
}
294+
}
295+
296+
let questionMarks = String(repeating: "?", count: questionMarkCount)
297+
298+
return (expr, questionMarks)
299+
}
300+
246301
/// Parse a condition argument from a member function call.
247302
///
248303
/// - Parameters:
249304
/// - expr: The function call expression.
250-
/// - memberAccessExpr: The called expression of `expr`.
251305
/// - macro: The macro expression being expanded.
252306
/// - context: The macro context in which the expression is being parsed.
253307
///
@@ -310,15 +364,20 @@ private func _parseCondition(from expr: FunctionCallExprSyntax, for macro: some
310364
.map(\.expression)
311365
.map { Argument(expression: $0) }
312366

367+
var baseExprForSourceCode: ExprSyntax?
313368
var conditionArguments = [Argument]()
314-
if let memberAccessExpr, let baseExpr = memberAccessExpr.base {
369+
if let memberAccessExpr, var baseExpr = memberAccessExpr.base {
370+
let questionMarks: String
371+
(baseExpr, questionMarks) = _exprFromOptionalChainedExpr(baseExpr)
372+
baseExprForSourceCode = baseExpr
373+
315374
conditionArguments.append(Argument(expression: "\(baseExpr.trimmed).self")) // BUG: rdar://113152370
316375
conditionArguments.append(
317376
Argument(
318377
label: "calling",
319378
expression: """
320379
{
321-
$0.\(functionName.trimmed)(\(LabeledExprListSyntax(indexedArguments)))
380+
$0\(raw: questionMarks).\(functionName.trimmed)(\(LabeledExprListSyntax(indexedArguments)))
322381
}
323382
"""
324383
)
@@ -345,7 +404,37 @@ private func _parseCondition(from expr: FunctionCallExprSyntax, for macro: some
345404
return Condition(
346405
expandedFunctionName,
347406
arguments: conditionArguments,
348-
sourceCode: createSourceCodeExprForFunctionCall(memberAccessExpr?.base, functionName, argumentList)
407+
sourceCode: createSourceCodeExprForFunctionCall(baseExprForSourceCode, functionName, argumentList)
408+
)
409+
}
410+
411+
/// Parse a condition argument from a property access.
412+
///
413+
/// - Parameters:
414+
/// - expr: The member access expression.
415+
/// - macro: The macro expression being expanded.
416+
/// - context: The macro context in which the expression is being parsed.
417+
///
418+
/// - Returns: An instance of ``Condition`` describing `expr`.
419+
private func _parseCondition(from expr: MemberAccessExprSyntax, for macro: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext) -> Condition {
420+
// Only handle member access expressions where the base expression is known
421+
// and where there are no argument names (which would otherwise indicate a
422+
// reference to a member function which wouldn't resolve to anything useful at
423+
// runtime.)
424+
guard var baseExpr = expr.base, expr.declName.argumentNames == nil else {
425+
return Condition(expression: expr)
426+
}
427+
428+
let questionMarks: String
429+
(baseExpr, questionMarks) = _exprFromOptionalChainedExpr(baseExpr)
430+
431+
return Condition(
432+
"__checkPropertyAccess",
433+
arguments: [
434+
Argument(expression: "\(baseExpr.trimmed).self"),
435+
Argument(label: "getting", expression: "{ $0\(raw: questionMarks).\(expr.declName.baseName) }")
436+
],
437+
sourceCode: createSourceCodeExprForPropertyAccess(baseExpr, expr.declName)
349438
)
350439
}
351440

@@ -375,9 +464,11 @@ private func _parseCondition(from expr: ExprSyntax, for macro: some Freestanding
375464
return _parseCondition(from: closureExpr, for: macro, in: context)
376465
}
377466

378-
// Handle function calls.
467+
// Handle function calls and member accesses.
379468
if let functionCallExpr = expr.as(FunctionCallExprSyntax.self) {
380469
return _parseCondition(from: functionCallExpr, for: macro, in: context)
470+
} else if let memberAccessExpr = expr.as(MemberAccessExprSyntax.self) {
471+
return _parseCondition(from: memberAccessExpr, for: macro, in: context)
381472
}
382473

383474
// Parentheses are parsed as if they were tuples, so (true && false) appears

Sources/TestingMacros/Support/SourceCodeCapturing.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,12 @@ func createSourceCodeExprForFunctionCall(_ value: (some SyntaxProtocol)?, _ func
7575

7676
return ".__functionCall(\(arguments))"
7777
}
78+
79+
func createSourceCodeExprForPropertyAccess(_ value: ExprSyntax, _ keyPath: DeclReferenceExprSyntax) -> ExprSyntax {
80+
let arguments = LabeledExprListSyntax {
81+
LabeledExprSyntax(expression: StringLiteralExprSyntax(content: value.trimmedDescription))
82+
LabeledExprSyntax(expression: StringLiteralExprSyntax(content: keyPath.baseName.trimmedDescription))
83+
}
84+
85+
return ".__fromPropertyAccess(\(arguments))"
86+
}

Tests/TestingMacrosTests/ConditionMacroTests.swift

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ struct ConditionMacroTests {
6666
##"Testing.__checkValue(a.b(&c, d), sourceCode: .__fromSyntaxNode("a.b(&c, d)"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
6767
##"#expect(a.b(try c()))"##:
6868
##"Testing.__checkValue(a.b(try c()), sourceCode: .__fromSyntaxNode("a.b(try c())"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
69+
##"#expect(a?.b(c))"##:
70+
##"Testing.__checkFunctionCall(a.self, calling: { $0?.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
71+
##"#expect(a???.b(c))"##:
72+
##"Testing.__checkFunctionCall(a.self, calling: { $0???.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
73+
##"#expect(a?.b.c(d))"##:
74+
##"Testing.__checkFunctionCall(a?.b.self, calling: { $0?.c($1) }, d, sourceCode: .__functionCall("a?.b", "c", (nil, "d")), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
6975
##"#expect({}())"##:
7076
##"Testing.__checkValue({}(), sourceCode: .__fromSyntaxNode("{}()"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
7177
##"#expect(a.b(c: d))"##:
@@ -74,6 +80,12 @@ struct ConditionMacroTests {
7480
##"Testing.__checkValue(a.b { c }, sourceCode: .__fromSyntaxNode("a.b { c }"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
7581
##"#expect(a, sourceLocation: someValue)"##:
7682
##"Testing.__checkValue(a, sourceCode: .__fromSyntaxNode("a"), comments: [], isRequired: false, sourceLocation: someValue).__expected()"##,
83+
##"#expect(a.isB)"##:
84+
##"Testing.__checkPropertyAccess(a.self, getting: { $0.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
85+
##"#expect(a???.isB)"##:
86+
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
87+
##"#expect(a?.b.isB)"##:
88+
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, sourceCode: .__fromPropertyAccess("a?.b", "isB"), comments: [], isRequired: false, sourceLocation: Testing.SourceLocation()).__expected()"##,
7789
]
7890
)
7991
func expectMacro(input: String, expectedOutput: String) throws {
@@ -128,6 +140,12 @@ struct ConditionMacroTests {
128140
##"Testing.__checkValue(a.b(&c, d), sourceCode: .__fromSyntaxNode("a.b(&c, d)"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
129141
##"#require(a.b(try c()))"##:
130142
##"Testing.__checkValue(a.b(try c()), sourceCode: .__fromSyntaxNode("a.b(try c())"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
143+
##"#require(a?.b(c))"##:
144+
##"Testing.__checkFunctionCall(a.self, calling: { $0?.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
145+
##"#require(a???.b(c))"##:
146+
##"Testing.__checkFunctionCall(a.self, calling: { $0???.b($1) }, c, sourceCode: .__functionCall("a", "b", (nil, "c")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
147+
##"#require(a?.b.c(d))"##:
148+
##"Testing.__checkFunctionCall(a?.b.self, calling: { $0?.c($1) }, d, sourceCode: .__functionCall("a?.b", "c", (nil, "d")), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
131149
##"#require({}())"##:
132150
##"Testing.__checkValue({}(), sourceCode: .__fromSyntaxNode("{}()"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
133151
##"#require(a.b(c: d))"##:
@@ -136,6 +154,12 @@ struct ConditionMacroTests {
136154
##"Testing.__checkValue(a.b { c }, sourceCode: .__fromSyntaxNode("a.b { c }"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
137155
##"#require(a, sourceLocation: someValue)"##:
138156
##"Testing.__checkValue(a, sourceCode: .__fromSyntaxNode("a"), comments: [], isRequired: true, sourceLocation: someValue).__required()"##,
157+
##"#require(a.isB)"##:
158+
##"Testing.__checkPropertyAccess(a.self, getting: { $0.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
159+
##"#require(a???.isB)"##:
160+
##"Testing.__checkPropertyAccess(a.self, getting: { $0???.isB }, sourceCode: .__fromPropertyAccess("a", "isB"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
161+
##"#require(a?.b.isB)"##:
162+
##"Testing.__checkPropertyAccess(a?.b.self, getting: { $0?.isB }, sourceCode: .__fromPropertyAccess("a?.b", "isB"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
139163
]
140164
)
141165
func requireMacro(input: String, expectedOutput: String) throws {
@@ -147,7 +171,7 @@ struct ConditionMacroTests {
147171
@Test("Unwrapping #require() macro",
148172
arguments: [
149173
##"#require(Optional<Int>.none)"##:
150-
##"Testing.__checkValue(Optional<Int>.none, sourceCode: .__fromSyntaxNode("Optional<Int>.none"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
174+
##"Testing.__checkPropertyAccess(Optional<Int>.self, getting: { $0.none }, sourceCode: .__fromPropertyAccess("Optional<Int>", "none"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
151175
##"#require(nil ?? 123)"##:
152176
##"Testing.__checkBinaryOperation(nil, { $0 ?? $1() }, 123, sourceCode: .__fromBinaryOperation("nil", "??", "123"), comments: [], isRequired: true, sourceLocation: Testing.SourceLocation()).__required()"##,
153177
##"#require(123 ?? nil)"##:

0 commit comments

Comments
 (0)