Skip to content

Commit aabc849

Browse files
authored
[AutoDiff] Add @differentiable fix-it for protocols/classes. (#29332)
For protocol requirements and class members with `@differentiable` attribute, conforming types and subclasses must have the same `@differentiable` attribute (or one with a superset of differentiability parameters) on implementing/ overriding declarations. For implementing/overriding declarations that are missing a `@differentiable` attribute, emit a fix-it that adds the missing attribute. Resolves TF-1118.
1 parent 91fe7db commit aabc849

File tree

3 files changed

+41
-36
lines changed

3 files changed

+41
-36
lines changed

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -613,8 +613,8 @@ static bool hasOverridingDifferentiableAttribute(ValueDecl *derivedDecl,
613613
.getAttributes<DifferentiableAttr, /*AllowInvalid*/ true>();
614614
auto baseDAs = baseAFD->getAttrs().getAttributes<DifferentiableAttr>();
615615

616-
// Make sure all the `@differentiable` attributes in `baseDecl` are
617-
// also declared in `derivedDecl`.
616+
// Make sure all the `@differentiable` attributes on `baseDecl` are
617+
// also declared on `derivedDecl`.
618618
bool diagnosed = false;
619619
for (auto *baseDA : baseDAs) {
620620
auto baseParameters = baseDA->getParameterIndices();
@@ -640,21 +640,26 @@ static bool hasOverridingDifferentiableAttribute(ValueDecl *derivedDecl,
640640
if (defined)
641641
continue;
642642
diagnosed = true;
643-
// Omit printing wrt clause if attribute differentiation parameters match
644-
// inferred differentiation parameters.
643+
// Emit an error and fix-it showing the missing base declaration's
644+
// `@differentiable` attribute.
645+
// Omit printing `wrt:` clause if attribute's differentiability parameters
646+
// match inferred differentiability parameters.
645647
auto *inferredParameters =
646648
TypeChecker::inferDifferentiabilityParameters(derivedAFD, nullptr);
647649
bool omitWrtClause =
648650
!baseParameters ||
649651
baseParameters->getNumIndices() == inferredParameters->getNumIndices();
650652
// Get `@differentiable` attribute description.
651-
std::string baseDAString;
652-
llvm::raw_string_ostream stream(baseDAString);
653-
baseDA->print(stream, derivedDecl, omitWrtClause,
653+
std::string baseDiffAttrString;
654+
llvm::raw_string_ostream os(baseDiffAttrString);
655+
baseDA->print(os, derivedDecl, omitWrtClause,
654656
/*omitDerivativeFunctions*/ true);
655-
diags.diagnose(derivedDecl,
656-
diag::overriding_decl_missing_differentiable_attr,
657-
StringRef(stream.str()).trim());
657+
os.flush();
658+
diags
659+
.diagnose(derivedDecl,
660+
diag::overriding_decl_missing_differentiable_attr,
661+
baseDiffAttrString)
662+
.fixItInsert(derivedDecl->getStartLoc(), baseDiffAttrString + ' ');
658663
diags.diagnose(baseDecl, diag::overridden_here);
659664
}
660665
// If a diagnostic was produced, return false.

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2319,28 +2319,28 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
23192319
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
23202320
break;
23212321
case MatchKind::DifferentiableConflict: {
2322-
// Emit a note showing the missing requirement `@differentiable` attribute.
2322+
// Emit a note and fix-it showing the missing requirement `@differentiable`
2323+
// attribute.
23232324
auto *reqAttr = cast<DifferentiableAttr>(match.UnmetAttribute);
23242325
assert(reqAttr);
2325-
if (!reqAttr->getParameterIndices())
2326-
break;
2327-
// Omit printing wrt clause if attribute differentiation parameters match
2328-
// inferred differentiation parameters.
2326+
// Omit printing `wrt:` clause if attribute's differentiability
2327+
// parameters match inferred differentiability parameters.
23292328
auto *original = cast<AbstractFunctionDecl>(match.Witness);
23302329
auto *whereClauseGenEnv =
23312330
reqAttr->getDerivativeGenericEnvironment(original);
23322331
auto *inferredParameters = TypeChecker::inferDifferentiabilityParameters(
23332332
original, whereClauseGenEnv);
23342333
bool omitWrtClause = reqAttr->getParameterIndices()->getNumIndices() ==
23352334
inferredParameters->getNumIndices();
2336-
// Get `@differentiable` attribute description.
23372335
std::string reqDiffAttrString;
2338-
llvm::raw_string_ostream stream(reqDiffAttrString);
2339-
reqAttr->print(stream, req, omitWrtClause,
2340-
/*omitDerivativeFunctions*/ true);
2341-
diags.diagnose(match.Witness,
2342-
diag::protocol_witness_missing_differentiable_attr,
2343-
StringRef(stream.str()).trim());
2336+
llvm::raw_string_ostream os(reqDiffAttrString);
2337+
reqAttr->print(os, req, omitWrtClause, /*omitDerivativeFunctions*/ true);
2338+
os.flush();
2339+
diags
2340+
.diagnose(match.Witness,
2341+
diag::protocol_witness_missing_differentiable_attr,
2342+
reqDiffAttrString)
2343+
.fixItInsert(match.Witness->getStartLoc(), reqDiffAttrString + ' ');
23442344
break;
23452345
}
23462346
}

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ protocol ProtocolRequirements: Differentiable {
706706
}
707707

708708
protocol ProtocolRequirementsRefined: ProtocolRequirements {
709-
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}}
709+
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{3-3=@differentiable }}
710710
func f1(_ x: Float) -> Float
711711
}
712712

@@ -719,39 +719,39 @@ struct DiffAttrConformanceErrors: ProtocolRequirements {
719719
var y: Float
720720

721721
// FIXME(TF-284): Fix unexpected diagnostic.
722-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
722+
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
723723
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Float)'}}
724724
init(x: Float, y: Float) {
725725
self.x = x
726726
self.y = y
727727
}
728728

729729
// FIXME(TF-284): Fix unexpected diagnostic.
730-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
730+
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
731731
// expected-note @+1 {{candidate has non-matching type '(x: Float, y: Int)'}}
732732
init(x: Float, y: Int) {
733733
self.x = x
734734
self.y = Float(y)
735735
}
736736

737-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
737+
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
738738
// expected-note @+1 {{candidate has non-matching type '(Float, Float) -> Float'}}
739739
func amb(x: Float, y: Float) -> Float {
740740
return x
741741
}
742742

743-
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}}
743+
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} {{3-3=@differentiable(wrt: x) }}
744744
// expected-note @+1 {{candidate has non-matching type '(Float, Int) -> Float'}}
745745
func amb(x: Float, y: Int) -> Float {
746746
return x
747747
}
748748

749-
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
749+
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
750750
func f1(_ x: Float) -> Float {
751751
return x
752752
}
753753

754-
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
754+
// expected-note @+2 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
755755
@differentiable(wrt: (self, x))
756756
func f2(_ x: Float, _ y: Float) -> Float {
757757
return x + y
@@ -774,15 +774,15 @@ protocol ProtocolRequirementsWithDefault {
774774
func f1(_ x: Float) -> Float
775775
}
776776
extension ProtocolRequirementsWithDefault {
777-
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
777+
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
778778
func f1(_ x: Float) -> Float { x }
779779
}
780780
// expected-error @+1 {{type 'DiffAttrConformanceErrors2' does not conform to protocol 'ProtocolRequirementsWithDefault'}}
781781
struct DiffAttrConformanceErrors2: ProtocolRequirementsWithDefault {
782782
typealias TangentVector = DummyTangentVector
783783
mutating func move(along _: TangentVector) {}
784784

785-
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
785+
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
786786
func f1(_ x: Float) -> Float { x }
787787
}
788788

@@ -794,7 +794,7 @@ protocol NotRefiningDiffable {
794794

795795
// expected-error @+1 {{type 'CertainlyNotDiffableWrtSelf' does not conform to protocol 'NotRefiningDiffable'}}
796796
struct CertainlyNotDiffableWrtSelf: NotRefiningDiffable {
797-
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
797+
// expected-note @+1 {{candidate is missing attribute '@differentiable'}} {{3-3=@differentiable }}
798798
func a(_ x: Float) -> Float { return x * 5.0 }
799799
}
800800

@@ -813,7 +813,7 @@ struct TF285MissingOneDiffAttr: TF285 {
813813

814814
// Requirement is missing an attribute.
815815
@differentiable(wrt: x)
816-
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}}
816+
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}} {{3-3=@differentiable(wrt: (x, y)) }}
817817
func foo(x: Float, y: Float) -> Float {
818818
return x
819819
}
@@ -1047,7 +1047,7 @@ public protocol DifferentiableDistribution: Differentiable, Distribution {
10471047
// Adding a more general `@differentiable` attribute.
10481048
public protocol DoubleDifferentiableDistribution: DifferentiableDistribution
10491049
where Value: Differentiable {
1050-
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable(wrt: self)'}}
1050+
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable(wrt: self)'}} {{3-3=@differentiable(wrt: self) }}
10511051
func logProbability(of value: Value) -> Float
10521052
}
10531053

@@ -1137,8 +1137,8 @@ class Super: Differentiable {
11371137
}
11381138

11391139
class Sub: Super {
1140-
// expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}}
1141-
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}}
1140+
// expected-error @+2 {{overriding declaration is missing attribute '@differentiable(wrt: x)'}} {{12-12=@differentiable(wrt: x) }}
1141+
// expected-error @+1 {{overriding declaration is missing attribute '@differentiable'}} {{12-12=@differentiable }}
11421142
override func testMissingAttributes(_ x: Float) -> Float { x }
11431143

11441144
// expected-warning @+2 {{'jvp:' and 'vjp:' arguments in '@differentiable' attribute are deprecated}}

0 commit comments

Comments
 (0)