Skip to content

[AutoDiff] Add @differentiable fixit for protocols/classes. #29332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions lib/Sema/TypeCheckDeclOverride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,8 @@ static bool hasOverridingDifferentiableAttribute(ValueDecl *derivedDecl,
.getAttributes<DifferentiableAttr, /*AllowInvalid*/ true>();
auto baseDAs = baseAFD->getAttrs().getAttributes<DifferentiableAttr>();

// Make sure all the `@differentiable` attributes in `baseDecl` are
// also declared in `derivedDecl`.
// Make sure all the `@differentiable` attributes on `baseDecl` are
// also declared on `derivedDecl`.
bool diagnosed = false;
for (auto *baseDA : baseDAs) {
auto baseParameters = baseDA->getParameterIndices();
Expand All @@ -640,21 +640,26 @@ static bool hasOverridingDifferentiableAttribute(ValueDecl *derivedDecl,
if (defined)
continue;
diagnosed = true;
// Omit printing wrt clause if attribute differentiation parameters match
// inferred differentiation parameters.
// Emit an error and fix-it showing the missing base declaration's
// `@differentiable` attribute.
// Omit printing `wrt:` clause if attribute's differentiability parameters
// match inferred differentiability parameters.
auto *inferredParameters =
TypeChecker::inferDifferentiabilityParameters(derivedAFD, nullptr);
bool omitWrtClause =
!baseParameters ||
baseParameters->getNumIndices() == inferredParameters->getNumIndices();
// Get `@differentiable` attribute description.
std::string baseDAString;
llvm::raw_string_ostream stream(baseDAString);
baseDA->print(stream, derivedDecl, omitWrtClause,
std::string baseDiffAttrString;
llvm::raw_string_ostream os(baseDiffAttrString);
baseDA->print(os, derivedDecl, omitWrtClause,
/*omitDerivativeFunctions*/ true);
diags.diagnose(derivedDecl,
diag::overriding_decl_missing_differentiable_attr,
StringRef(stream.str()).trim());
os.flush();
diags
.diagnose(derivedDecl,
diag::overriding_decl_missing_differentiable_attr,
baseDiffAttrString)
.fixItInsert(derivedDecl->getStartLoc(), baseDiffAttrString + ' ');
diags.diagnose(baseDecl, diag::overridden_here);
}
// If a diagnostic was produced, return false.
Expand Down
24 changes: 12 additions & 12 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2319,28 +2319,28 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
break;
case MatchKind::DifferentiableConflict: {
// Emit a note showing the missing requirement `@differentiable` attribute.
// Emit a note and fix-it showing the missing requirement `@differentiable`
// attribute.
auto *reqAttr = cast<DifferentiableAttr>(match.UnmetAttribute);
assert(reqAttr);
if (!reqAttr->getParameterIndices())
break;
// Omit printing wrt clause if attribute differentiation parameters match
// inferred differentiation parameters.
// Omit printing `wrt:` clause if attribute's differentiability
// parameters match inferred differentiability parameters.
auto *original = cast<AbstractFunctionDecl>(match.Witness);
auto *whereClauseGenEnv =
reqAttr->getDerivativeGenericEnvironment(original);
auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters(
original, whereClauseGenEnv);
bool omitWrtClause = reqAttr->getParameterIndices()->getNumIndices() ==
inferredParameters->getNumIndices();
// Get `@differentiable` attribute description.
std::string reqDiffAttrString;
llvm::raw_string_ostream stream(reqDiffAttrString);
reqAttr->print(stream, req, omitWrtClause,
/*omitDerivativeFunctions*/ true);
diags.diagnose(match.Witness,
diag::protocol_witness_missing_differentiable_attr,
StringRef(stream.str()).trim());
llvm::raw_string_ostream os(reqDiffAttrString);
reqAttr->print(os, req, omitWrtClause, /*omitDerivativeFunctions*/ true);
os.flush();
diags
.diagnose(match.Witness,
diag::protocol_witness_missing_differentiable_attr,
reqDiffAttrString)
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
break;
}
}
Expand Down
28 changes: 14 additions & 14 deletions test/AutoDiff/Sema/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ protocol ProtocolRequirements: Differentiable {
}

protocol ProtocolRequirementsRefined: ProtocolRequirements {
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}}
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{3-3=@differentiable }}
func f1(_ x: Float) -> Float
}

Expand All @@ -719,39 +719,39 @@ struct DiffAttrConformanceErrors: ProtocolRequirements {
var y: Float

// FIXME(TF-284): Fix unexpected diagnostic.
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}}
init(x: Float, y: Float) {
self.x = x
self.y = y
}

// FIXME(TF-284): Fix unexpected diagnostic.
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}}
init(x: Float, y: Int) {
self.x = x
self.y = Float(y)
}

// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
// expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}}
func amb(x: Float, y: Float) -> Float {
return x
}

// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}}
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{3-3=@differentiable(wrt: x) }}
// expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}}
func amb(x: Float, y: Int) -> Float {
return x
}

// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
func f1(_ x: Float) -> Float {
return x
}

// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
@differentiable(wrt: (self, x))
func f2(_ x: Float, _ y: Float) -> Float {
return x + y
Expand All @@ -774,15 +774,15 @@ protocol ProtocolRequirementsWithDefault {
func f1(_ x: Float) -> Float
}
extension ProtocolRequirementsWithDefault {
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
func f1(_ x: Float) -> Float { x }
}
// expected-error @+1 {{type 'DiffAttrConformanceErrors2' does not conform to protocol 'ProtocolRequirementsWithDefault'}}
struct DiffAttrConformanceErrors2: ProtocolRequirementsWithDefault {
typealias TangentVector = DummyTangentVector
mutating func move(along _: TangentVector) {}

// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
func f1(_ x: Float) -> Float { x }
}

Expand All @@ -794,7 +794,7 @@ protocol NotRefiningDiffable {

// expected-error @+1 {{type 'CertainlyNotDiffableWrtSelf' does not conform to protocol 'NotRefiningDiffable'}}
struct CertainlyNotDiffableWrtSelf: NotRefiningDiffable {
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
func a(_ x: Float) -> Float { return x * 5.0 }
}

Expand All @@ -813,7 +813,7 @@ struct TF285MissingOneDiffAttr: TF285 {

// Requirement is missing an attribute.
@differentiable(wrt: x)
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}}
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}} {{3-3=@differentiable(wrt: (x, y)) }}
func foo(x: Float, y: Float) -> Float {
return x
}
Expand Down Expand Up @@ -1047,7 +1047,7 @@ public protocol DifferentiableDistribution: Differentiable, Distribution {
// Adding a more general `@differentiable` attribute.
public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
where Value: Differentiable {
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable(wrt: self)'}}
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable(wrt: self)'}} {{3-3=@differentiable(wrt: self) }}
func logProbability(of value: Value) -> Float
}

Expand Down Expand Up @@ -1137,8 +1137,8 @@ class Super: Differentiable {
}

class Sub: Super {
// expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}}
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}}
// expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}} {{12-12=@differentiable(wrt: x) }}
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{12-12=@differentiable }}
override func testMissingAttributes(_ x: Float) -> Float { x }

// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}
Expand Down