Skip to content

Commit e29c19c

Browse files
authored
[AutoDiff] [Sema] Fix '@differentiable' witness matching regression. (#34533)
When checking protocol conformances with `@differentiable` requirements, the type checker is supposed to accept omissions of `@differentiable` attributes when there exsits an attribute that covers a superset of the differentiation configuration. This was accidentally regressed in #33776 which made the following test case fail to compile. This is fixed by adjusting the witness matching conditions. ```swift // rdar://70348904 reproducer: public protocol P: Differentiable { @differentiable(wrt: self) @differentiable(wrt: (self, x)) func foo(_ x: Float) -> Float } public struct S: P {} extension S { // This had worked until #33776. @differentiable(wrt: (self, x)) public func foo(_ x: Float) -> Float { x } } ``` Also fix some suboptimal diagnostics where more information could be shown. Resolves rdar://70348904.
1 parent 16dec4c commit e29c19c

File tree

5 files changed

+51
-46
lines changed

5 files changed

+51
-46
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3121,8 +3121,6 @@ ERROR(differentiable_attr_layout_req_unsupported,none,
31213121
())
31223122
ERROR(overriding_decl_missing_differentiable_attr,none,
31233123
"overriding declaration is missing attribute '%0'", (StringRef))
3124-
NOTE(protocol_witness_missing_differentiable_attr,none,
3125-
"candidate is missing attribute '%0'", (StringRef))
31263124
NOTE(protocol_witness_missing_differentiable_attr_invalid_context,none,
31273125
"candidate is missing explicit '%0' attribute to satisfy requirement %1 "
31283126
"(in protocol %3); explicit attribute is necessary because candidate is "

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -383,20 +383,27 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
383383
reqDiffAttr->getParameterIndices()))
384384
supersetConfig = witnessConfig;
385385
}
386+
387+
// If no exact witness derivative configuration was found, check conditions
388+
// for creating an implicit witness `@differentiable` attribute with the
389+
// exact derivative configuration.
386390
if (!foundExactConfig) {
387-
// If no exact witness derivative configuration was found, check
388-
// conditions for creating an implicit witness `@differentiable` attribute
389-
// with the exact derivative configuration.
390-
391-
// If witness is declared in a different file or type context than the
392-
// conformance, we should not create an implicit `@differentiable`
393-
// attribute on the witness. Produce an error.
394-
auto sameTypeContext =
395-
dc->getInnermostTypeContext() ==
391+
auto witnessInDifferentFile =
392+
dc->getParentSourceFile() !=
393+
witness->getDeclContext()->getParentSourceFile();
394+
auto witnessInDifferentTypeContext =
395+
dc->getInnermostTypeContext() !=
396396
witness->getDeclContext()->getInnermostTypeContext();
397-
auto sameModule = dc->getModuleScopeContext() ==
398-
witness->getDeclContext()->getModuleScopeContext();
399-
if (!sameTypeContext || !sameModule) {
397+
// Produce an error instead of creating an implicit `@differentiable`
398+
// attribute if any of the following conditions are met:
399+
// - The witness is in a different file than the conformance
400+
// declaration.
401+
// - The witness is in a different type context (i.e. extension) than
402+
// the conformance declaration, and there is no existing
403+
// `@differentiable` attribute that covers the required differentiation
404+
// parameters.
405+
if (witnessInDifferentFile ||
406+
(witnessInDifferentTypeContext && !supersetConfig)) {
400407
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
401408
// appear if associated type inference is involved.
402409
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
@@ -2492,30 +2499,14 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
24922499
llvm::raw_string_ostream os(reqDiffAttrString);
24932500
reqAttr->print(os, req, omitWrtClause);
24942501
os.flush();
2495-
// If the witness is declared in a different file or type context than the
2496-
// conformance, emit a specialized diagnostic.
2497-
auto sameModule = conformance->getDeclContext()->getModuleScopeContext() !=
2498-
witness->getDeclContext()->getModuleScopeContext();
2499-
auto sameTypeContext =
2500-
conformance->getDeclContext()->getInnermostTypeContext() !=
2501-
witness->getDeclContext()->getInnermostTypeContext();
2502-
if (sameModule || sameTypeContext) {
2503-
diags
2504-
.diagnose(
2505-
witness,
2506-
diag::
2507-
protocol_witness_missing_differentiable_attr_invalid_context,
2508-
reqDiffAttrString, req->getName(), conformance->getType(),
2509-
conformance->getProtocol()->getDeclaredInterfaceType())
2510-
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
2511-
}
2512-
// Otherwise, emit a general "missing attribute" diagnostic.
2513-
else {
2514-
diags
2515-
.diagnose(witness, diag::protocol_witness_missing_differentiable_attr,
2516-
reqDiffAttrString)
2517-
.fixItInsert(witness->getStartLoc(), reqDiffAttrString + ' ');
2518-
}
2502+
diags
2503+
.diagnose(
2504+
witness,
2505+
diag::
2506+
protocol_witness_missing_differentiable_attr_invalid_context,
2507+
reqDiffAttrString, req->getName(), conformance->getType(),
2508+
conformance->getProtocol()->getDeclaredInterfaceType())
2509+
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
25192510
break;
25202511
}
25212512
case MatchKind::EnumCaseWithAssociatedValues:

test/AutoDiff/Sema/ImplicitDifferentiableAttributeCrossFile/Inputs/other_file.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ protocol Protocol1: Differentiable {
1616
func internalMethod3(_ x: Float) -> Float
1717
}
1818

19-
protocol Protocol2: Differentiable {
19+
public protocol Protocol2: Differentiable {
20+
@differentiable(wrt: self)
2021
@differentiable(wrt: (self, x))
2122
func internalMethod4(_ x: Float) -> Float
2223
}

test/AutoDiff/Sema/ImplicitDifferentiableAttributeCrossFile/main.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
// RUN: %target-swift-frontend -c %s -primary-file %S/Inputs/other_file.swift
2424
// RUN: not %target-build-swift %s %S/Inputs/other_file.swift
2525

26+
import _Differentiation
27+
2628
// Error: conformance is in different file than witnesses.
2729
// expected-error @+1 {{type 'ConformingStruct' does not conform to protocol 'Protocol1'}}
2830
extension ConformingStruct: Protocol1 {}
@@ -33,3 +35,16 @@ extension ConformingStruct: Protocol2 {
3335
x
3436
}
3537
}
38+
39+
public final class ConformingStructWithSupersetAttr: Protocol2 {}
40+
41+
// rdar://70348904: Witness mismatch failure when a matching witness with a *superset* `@differentiable`
42+
// attribute is specified.
43+
//
44+
// Note that public witnesses are required to explicitly specify `@differentiable` attributes except
45+
// those w.r.t. parameters that have already been covered by an existing `@differentiable` attribute.
46+
extension ConformingStructWithSupersetAttr {
47+
// @differentiable(wrt: self) // Omitting this is okay.
48+
@differentiable
49+
public func internalMethod4(_ x: Float) -> Float { x }
50+
}

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -291,39 +291,39 @@ public struct PublicDiffAttrConformance: ProtocolRequirements {
291291
var y: Float
292292

293293
// FIXME(TF-284): Fix unexpected diagnostic.
294-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }}
294+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}} {{10-10=@differentiable }}
295295
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}}
296296
public init(x: Float, y: Float) {
297297
self.x = x
298298
self.y = y
299299
}
300300

301301
// FIXME(TF-284): Fix unexpected diagnostic.
302-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }}
302+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}} {{10-10=@differentiable }}
303303
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}}
304304
public init(x: Float, y: Int) {
305305
self.x = x
306306
self.y = Float(y)
307307
}
308308

309-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }}
309+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}} {{10-10=@differentiable }}
310310
// expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}}
311311
public func amb(x: Float, y: Float) -> Float {
312312
return x
313313
}
314314

315-
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{10-10=@differentiable(wrt: x) }}
315+
// expected-note @+2 {{candidate is missing explicit '@differentiable(wrt: x)' attribute to satisfy requirement}} {{10-10=@differentiable(wrt: x) }}
316316
// expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}}
317317
public func amb(x: Float, y: Int) -> Float {
318318
return x
319319
}
320320

321-
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
321+
// expected-note @+1 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}}
322322
public func f1(_ x: Float) -> Float {
323323
return x
324324
}
325325

326-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
326+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}}
327327
@differentiable(wrt: (self, x))
328328
public func f2(_ x: Float, _ y: Float) -> Float {
329329
return x + y
@@ -558,7 +558,7 @@ public struct AttemptsToSatisfyRequirement: HasRequirement {
558558
// This `@differentiable` attribute does not satisfy the requirement because
559559
// it is mroe constrained than the requirement's `@differentiable` attribute.
560560
@differentiable(where T: CustomStringConvertible)
561-
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}}
561+
// expected-note @+1 {{candidate is missing explicit '@differentiable(wrt: (x, y))' attribute to satisfy requirement}}
562562
public func requirement<T: Differentiable>(_ x: T, _ y: T) -> T { x }
563563
}
564564

0 commit comments

Comments
 (0)