Skip to content
This repository was archived by the owner on Jan 10, 2023. It is now read-only.

Commit 1691310

Browse files
committed
Merge branch 'main' of github.com:apple/swift into tensorflow-stage
* 'main' of github.com:apple/swift: [AutoDiff] [Sema] Fix '@differentiable' witness matching regression. (swiftlang#34533)
2 parents 1a3543c + e29c19c commit 1691310

File tree

5 files changed

+51
-46
lines changed

5 files changed

+51
-46
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3142,8 +3142,6 @@ ERROR(differentiable_attr_layout_req_unsupported,none,
31423142
())
31433143
ERROR(overriding_decl_missing_differentiable_attr,none,
31443144
"overriding declaration is missing attribute '%0'", (StringRef))
3145-
NOTE(protocol_witness_missing_differentiable_attr,none,
3146-
"candidate is missing attribute '%0'", (StringRef))
31473145
NOTE(protocol_witness_missing_differentiable_attr_invalid_context,none,
31483146
"candidate is missing explicit '%0' attribute to satisfy requirement %1 "
31493147
"(in protocol %3); explicit attribute is necessary because candidate is "

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 27 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -383,20 +383,27 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
383383
reqDiffAttr->getParameterIndices()))
384384
supersetConfig = witnessConfig;
385385
}
386+
387+
// If no exact witness derivative configuration was found, check conditions
388+
// for creating an implicit witness `@differentiable` attribute with the
389+
// exact derivative configuration.
386390
if (!foundExactConfig) {
387-
// If no exact witness derivative configuration was found, check
388-
// conditions for creating an implicit witness `@differentiable` attribute
389-
// with the exact derivative configuration.
390-
391-
// If witness is declared in a different file or type context than the
392-
// conformance, we should not create an implicit `@differentiable`
393-
// attribute on the witness. Produce an error.
394-
auto sameTypeContext =
395-
dc->getInnermostTypeContext() ==
391+
auto witnessInDifferentFile =
392+
dc->getParentSourceFile() !=
393+
witness->getDeclContext()->getParentSourceFile();
394+
auto witnessInDifferentTypeContext =
395+
dc->getInnermostTypeContext() !=
396396
witness->getDeclContext()->getInnermostTypeContext();
397-
auto sameModule = dc->getModuleScopeContext() ==
398-
witness->getDeclContext()->getModuleScopeContext();
399-
if (!sameTypeContext || !sameModule) {
397+
// Produce an error instead of creating an implicit `@differentiable`
398+
// attribute if any of the following conditions are met:
399+
// - The witness is in a different file than the conformance
400+
// declaration.
401+
// - The witness is in a different type context (i.e. extension) than
402+
// the conformance declaration, and there is no existing
403+
// `@differentiable` attribute that covers the required differentiation
404+
// parameters.
405+
if (witnessInDifferentFile ||
406+
(witnessInDifferentTypeContext && !supersetConfig)) {
400407
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
401408
// appear if associated type inference is involved.
402409
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
@@ -2492,30 +2499,14 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
24922499
llvm::raw_string_ostream os(reqDiffAttrString);
24932500
reqAttr->print(os, req, omitWrtClause);
24942501
os.flush();
2495-
// If the witness is declared in a different file or type context than the
2496-
// conformance, emit a specialized diagnostic.
2497-
auto sameModule = conformance->getDeclContext()->getModuleScopeContext() !=
2498-
witness->getDeclContext()->getModuleScopeContext();
2499-
auto sameTypeContext =
2500-
conformance->getDeclContext()->getInnermostTypeContext() !=
2501-
witness->getDeclContext()->getInnermostTypeContext();
2502-
if (sameModule || sameTypeContext) {
2503-
diags
2504-
.diagnose(
2505-
witness,
2506-
diag::
2507-
protocol_witness_missing_differentiable_attr_invalid_context,
2508-
reqDiffAttrString, req->getName(), conformance->getType(),
2509-
conformance->getProtocol()->getDeclaredInterfaceType())
2510-
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
2511-
}
2512-
// Otherwise, emit a general "missing attribute" diagnostic.
2513-
else {
2514-
diags
2515-
.diagnose(witness, diag::protocol_witness_missing_differentiable_attr,
2516-
reqDiffAttrString)
2517-
.fixItInsert(witness->getStartLoc(), reqDiffAttrString + ' ');
2518-
}
2502+
diags
2503+
.diagnose(
2504+
witness,
2505+
diag::
2506+
protocol_witness_missing_differentiable_attr_invalid_context,
2507+
reqDiffAttrString, req->getName(), conformance->getType(),
2508+
conformance->getProtocol()->getDeclaredInterfaceType())
2509+
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
25192510
break;
25202511
}
25212512
case MatchKind::EnumCaseWithAssociatedValues:

test/AutoDiff/Sema/ImplicitDifferentiableAttributeCrossFile/Inputs/other_file.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ protocol Protocol1: Differentiable {
1616
func internalMethod3(_ x: Float) -> Float
1717
}
1818

19-
protocol Protocol2: Differentiable {
19+
public protocol Protocol2: Differentiable {
20+
@differentiable(wrt: self)
2021
@differentiable(wrt: (self, x))
2122
func internalMethod4(_ x: Float) -> Float
2223
}

test/AutoDiff/Sema/ImplicitDifferentiableAttributeCrossFile/main.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
// RUN: %target-swift-frontend -c %s -primary-file %S/Inputs/other_file.swift
2424
// RUN: not %target-build-swift %s %S/Inputs/other_file.swift
2525

26+
import _Differentiation
27+
2628
// Error: conformance is in different file than witnesses.
2729
// expected-error @+1 {{type 'ConformingStruct' does not conform to protocol 'Protocol1'}}
2830
extension ConformingStruct: Protocol1 {}
@@ -33,3 +35,16 @@ extension ConformingStruct: Protocol2 {
3335
x
3436
}
3537
}
38+
39+
public final class ConformingStructWithSupersetAttr: Protocol2 {}
40+
41+
// rdar://70348904: Witness mismatch failure when a matching witness with a *superset* `@differentiable`
42+
// attribute is specified.
43+
//
44+
// Note that public witnesses are required to explicitly specify `@differentiable` attributes except
45+
// those w.r.t. parameters that have already been covered by an existing `@differentiable` attribute.
46+
extension ConformingStructWithSupersetAttr {
47+
// @differentiable(wrt: self) // Omitting this is okay.
48+
@differentiable
49+
public func internalMethod4(_ x: Float) -> Float { x }
50+
}

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -289,39 +289,39 @@ public struct PublicDiffAttrConformance: ProtocolRequirements {
289289
var y: Float
290290

291291
// FIXME(TF-284): Fix unexpected diagnostic.
292-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }}
292+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}} {{10-10=@differentiable }}
293293
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}}
294294
public init(x: Float, y: Float) {
295295
self.x = x
296296
self.y = y
297297
}
298298

299299
// FIXME(TF-284): Fix unexpected diagnostic.
300-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }}
300+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}} {{10-10=@differentiable }}
301301
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}}
302302
public init(x: Float, y: Int) {
303303
self.x = x
304304
self.y = Float(y)
305305
}
306306

307-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{10-10=@differentiable }}
307+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}} {{10-10=@differentiable }}
308308
// expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}}
309309
public func amb(x: Float, y: Float) -> Float {
310310
return x
311311
}
312312

313-
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{10-10=@differentiable(wrt: x) }}
313+
// expected-note @+2 {{candidate is missing explicit '@differentiable(wrt: x)' attribute to satisfy requirement}} {{10-10=@differentiable(wrt: x) }}
314314
// expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}}
315315
public func amb(x: Float, y: Int) -> Float {
316316
return x
317317
}
318318

319-
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
319+
// expected-note @+1 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}}
320320
public func f1(_ x: Float) -> Float {
321321
return x
322322
}
323323

324-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
324+
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement}}
325325
@differentiable(wrt: (self, x))
326326
public func f2(_ x: Float, _ y: Float) -> Float {
327327
return x + y
@@ -556,7 +556,7 @@ public struct AttemptsToSatisfyRequirement: HasRequirement {
556556
// This `@differentiable` attribute does not satisfy the requirement because
557557
// it is mroe constrained than the requirement's `@differentiable` attribute.
558558
@differentiable(where T: CustomStringConvertible)
559-
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}}
559+
// expected-note @+1 {{candidate is missing explicit '@differentiable(wrt: (x, y))' attribute to satisfy requirement}}
560560
public func requirement<T: Differentiable>(_ x: T, _ y: T) -> T { x }
561561
}
562562

0 commit comments

Comments
 (0)