Skip to content

[AutoDiff] [Sema] Limit implicit @differentiable attribute creation. #33776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3112,12 +3112,11 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
"overriding declaration is missing attribute '%0'", (StringRef))
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))
NOTE(protocol_witness_missing_differentiable_attr_nonpublic_other_file,none,
"non-public %1 %2 must have explicit '%0' attribute to satisfy "
"requirement %3 %4 (in protocol %6) because it is declared in a different "
"file than the conformance of %5 to %6",
(StringRef, DescriptiveDeclKind, DeclName, DescriptiveDeclKind, DeclName,
Type, Type))
NOTE(protocol_witness_missing_differentiable_attr_invalid_context,none,
"candidate is missing explicit '%0' attribute to satisfy requirement %1 "
"(in protocol %3); explicit attribute is necessary because candidate is "
"declared in a different type context or file than the conformance of %2 "
"to %3", (StringRef, DeclName, Type, Type))

// @derivative
ERROR(derivative_attr_expected_result_tuple,none,
Expand Down
59 changes: 35 additions & 24 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -384,23 +384,19 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
supersetConfig = witnessConfig;
}
if (!foundExactConfig) {
bool success = false;
// If no exact witness derivative configuration was found, check
// conditions for creating an implicit witness `@differentiable` attribute
// with the exact derivative configuration:
// - If the witness has a "superset" derivative configuration.
// - If the witness is less than public and is declared in the same file
// as the conformance.
// - `@differentiable` attributes are really only significant for public
// declarations: it improves usability to not require explicit
// `@differentiable` attributes for less-visible declarations.
bool createImplicitWitnessAttribute =
supersetConfig || witness->getFormalAccess() < AccessLevel::Public;
// If the witness has less-than-public visibility and is declared in a
// different file than the conformance, produce an error.
if (!supersetConfig && witness->getFormalAccess() < AccessLevel::Public &&
dc->getModuleScopeContext() !=
witness->getDeclContext()->getModuleScopeContext()) {
// with the exact derivative configuration.

// If witness is declared in a different file or type context than the
// conformance, we should not create an implicit `@differentiable`
// attribute on the witness. Produce an error.
auto sameTypeContext =
dc->getInnermostTypeContext() ==
witness->getDeclContext()->getInnermostTypeContext();
auto sameModule = dc->getModuleScopeContext() ==
witness->getDeclContext()->getModuleScopeContext();
if (!sameTypeContext || !sameModule) {
// FIXME(TF-1014): `@differentiable` attribute diagnostic does not
// appear if associated type inference is involved.
if (auto *vdWitness = dyn_cast<VarDecl>(witness)) {
Expand All @@ -412,6 +408,20 @@ matchWitnessDifferentiableAttr(DeclContext *dc, ValueDecl *req,
reqDiffAttr);
}
}

// Otherwise, the witness must:
// - Have a "superset" derivative configuration.
// - Have less than public visibility.
// - `@differentiable` attributes are really only significant for
// public declarations: it improves usability to not require
// explicit `@differentiable` attributes for less-visible
// declarations.
//
// If these conditions are met, an implicit `@differentiable` attribute
// with the exact derivative configuration can be created.
bool success = false;
bool createImplicitWitnessAttribute =
supersetConfig || witness->getFormalAccess() < AccessLevel::Public;
if (createImplicitWitnessAttribute) {
auto derivativeGenSig = witnessAFD->getGenericSignature();
if (supersetConfig)
Expand Down Expand Up @@ -2448,19 +2458,20 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
llvm::raw_string_ostream os(reqDiffAttrString);
reqAttr->print(os, req, omitWrtClause);
os.flush();
// If the witness has less-than-public visibility and is declared in a
// different file than the conformance, emit a specialized diagnostic.
if (witness->getFormalAccess() < AccessLevel::Public &&
conformance->getDeclContext()->getModuleScopeContext() !=
witness->getDeclContext()->getModuleScopeContext()) {
// If the witness is declared in a different file or type context than the
// conformance, emit a specialized diagnostic.
auto sameModule = conformance->getDeclContext()->getModuleScopeContext() !=
witness->getDeclContext()->getModuleScopeContext();
auto sameTypeContext =
conformance->getDeclContext()->getInnermostTypeContext() !=
witness->getDeclContext()->getInnermostTypeContext();
if (sameModule || sameTypeContext) {
diags
.diagnose(
witness,
diag::
protocol_witness_missing_differentiable_attr_nonpublic_other_file,
reqDiffAttrString, witness->getDescriptiveKind(),
witness->getName(), req->getDescriptiveKind(),
req->getName(), conformance->getType(),
protocol_witness_missing_differentiable_attr_invalid_context,
reqDiffAttrString, req->getName(), conformance->getType(),
conformance->getProtocol()->getDeclaredInterfaceType())
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import _Differentiation

protocol Protocol1: Differentiable {
// expected-note @+2 {{protocol requires function 'internalMethod1' with type '(Float) -> Float'}}
@differentiable(wrt: (self, x))
func internalMethod1(_ x: Float) -> Float

// expected-note @+3 {{protocol requires function 'internalMethod2' with type '(Float) -> Float'}}
@differentiable(wrt: x)
@differentiable(wrt: (self, x))
func internalMethod2(_ x: Float) -> Float

// expected-note @+3 {{protocol requires function 'internalMethod3' with type '(Float) -> Float'}}
@differentiable(wrt: x)
@differentiable(wrt: (self, x))
func internalMethod3(_ x: Float) -> Float
}

protocol Protocol2: Differentiable {
@differentiable(wrt: (self, x))
func internalMethod4(_ x: Float) -> Float
}

// Note:
// - No `ConformingStruct: Protocol1` conformance exists in this file, so this
// file should compile just file.
// - A `ConformingStruct: Protocol1` conformance in a different file should be
// diagnosed to prevent linker errors. Without a diagnostic, compilation of
// the other file creates external references to symbols for implicit
// `@differentiable` attributes, even though no such symbols exist.
// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721

struct ConformingStruct: Differentiable {
// Error for missing `@differentiable` attribute.
// expected-note @+1 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement 'internalMethod1' (in protocol 'Protocol1'); explicit attribute is necessary because candidate is declared in a different type context or file than the conformance of 'ConformingStruct' to 'Protocol1'}} {{3-3=@differentiable }}
func internalMethod1(_ x: Float) -> Float {
x
}

// Error for missing `@differentiable` superset attribute.
// expected-note @+2 {{candidate is missing explicit '@differentiable' attribute to satisfy requirement 'internalMethod2' (in protocol 'Protocol1'); explicit attribute is necessary because candidate is declared in a different type context or file than the conformance of 'ConformingStruct' to 'Protocol1'}} {{3-3=@differentiable }}
@differentiable(wrt: x)
func internalMethod2(_ x: Float) -> Float {
x
}

// Error for missing `@differentiable` subset attribute.
// expected-note @+2 {{candidate is missing explicit '@differentiable(wrt: x)' attribute to satisfy requirement 'internalMethod3' (in protocol 'Protocol1'); explicit attribute is necessary because candidate is declared in a different type context or file than the conformance of 'ConformingStruct' to 'Protocol1'}} {{3-3=@differentiable(wrt: x) }}
@differentiable(wrt: (self, x))
func internalMethod3(_ x: Float) -> Float {
x
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import _Differentiation

protocol P1: Differentiable {
@differentiable(wrt: self)
// expected-note @+1 {{protocol requires function 'callAsFunction' with type '(Float) -> Float'}}
func callAsFunction(_ input: Float) -> Float
}

protocol P2: P1 {}

extension P2 {
@differentiable(wrt: (self, input))
// expected-note @+1 {{candidate is missing explicit '@differentiable(wrt: self)' attribute to satisfy requirement 'callAsFunction' (in protocol 'P1'); explicit attribute is necessary because candidate is declared in a different type context or file than the conformance of 'ConformingStruct' to 'P1'}}
public func callAsFunction(_ input: Float) -> Float {
return input
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Test missing protocol requirement `@differentiable` attribute errors for
// non-public protocol witnesses, when the protocol conformance is declared in a
// separate file from witnesses.
//
// Implicit `@differentiable` attributes cannot be generated for protocol
// witnesses when the conformance is declared from a separate file from the
// witness. Otherwise, compilation of the file containing the conformance
// creates references to external symbols for implicit `@differentiable`
// attributes, even though no such symbols exist.
//
// Context: https://github.com/apple/swift/pull/29771#issuecomment-585059721

// Note: `swiftc main.swift other_file.swift` runs three commands:
// - `swiftc -frontend -primary-file main.swift other_file.swift -o ...`
// - `swiftc -frontend main.swift -primary-file other_file.swift -o ...`
// - `/usr/bin/ld ...`
//
// `%target-build-swift` performs `swiftc main.swift other_file.swift`, so it is expected to fail (hence `not`).
// `swiftc -frontend -primary-file main.swift other_file.swift` should fail, so `-verify` is needed.
// `swiftc -frontend main.swift -primary-file other_file.swift` should succeed, so no need for `-verify`.

// RUN: %target-swift-frontend -c -verify -primary-file %s %S/Inputs/other_file.swift
// RUN: %target-swift-frontend -c %s -primary-file %S/Inputs/other_file.swift
// RUN: not %target-build-swift %s %S/Inputs/other_file.swift

// Error: conformance is in different file than witnesses.
// expected-error @+1 {{type 'ConformingStruct' does not conform to protocol 'Protocol1'}}
extension ConformingStruct: Protocol1 {}

// No error: conformance is in same file as witnesses.
extension ConformingStruct: Protocol2 {
func internalMethod4(_ x: Float) -> Float {
x
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: %target-swift-frontend -c -verify -primary-file %s %S/Inputs/other_file_protocol_default_implementation_witness.swift

// SR-13455: Test missing protocol requirement `@differentiable` attribute
// errors for protocol witnesses declared in a different file than the protocol
// conformance.
//
// This test case specifically tests protocol extension method witnesses.

import _Differentiation

// expected-error @+1 {{type 'ConformingStruct' does not conform to protocol 'P1'}}
struct ConformingStruct: P2 {}