Skip to content

Commit deb70d3

Browse files
authored
[AutoDiff] Fix @differentiable attribute type checking crasher. (#26405)
Diagnose invalid `Differentiable` conformances during `@differentiable` attribute type checking. Reject conformances that cannot resolve a `TangentVector` witness type. Resolves TF-521.
1 parent 844813e commit deb70d3

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2658,9 +2658,14 @@ static bool conformsToDifferentiable(Type type, DeclContext *DC) {
26582658
auto &ctx = type->getASTContext();
26592659
auto *differentiableProto =
26602660
ctx.getProtocol(KnownProtocolKind::Differentiable);
2661-
return TypeChecker::conformsToProtocol(type, differentiableProto, DC,
2662-
ConformanceCheckFlags::InExpression)
2663-
.hasValue();
2661+
auto conf = TypeChecker::conformsToProtocol(
2662+
type, differentiableProto, DC, ConformanceCheckFlags::InExpression);
2663+
if (!conf)
2664+
return false;
2665+
// Try to get the `TangentVector` type witness, in case the conformance has
2666+
// not been fully checked and the type witness cannot be resolved.
2667+
Type tanType = conf->getTypeWitnessByName(type, ctx.Id_TangentVector);
2668+
return !tanType.isNull();
26642669
};
26652670

26662671
// SWIFT_ENABLE_TENSORFLOW

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,36 @@ struct TF285MissingOneDiffAttr : TF285 {
751751
}
752752
}
753753

754+
// TF-521: Test invalid `@differentiable` attribute due to invalid
755+
// `Differentiable` conformance (`TangentVector` does not conform to
756+
// `AdditiveArithmetic`).
757+
struct TF_521<T: FloatingPoint> {
758+
var real: T
759+
var imaginary: T
760+
761+
// expected-error @+1 {{can only differentiate functions with results that conform to 'Differentiable', but 'TF_521<T>' does not conform to 'Differentiable'}}
762+
@differentiable(vjp: _vjpInit where T: Differentiable, T == T.TangentVector)
763+
init(real: T = 0, imaginary: T = 0) {
764+
self.real = real
765+
self.imaginary = imaginary
766+
}
767+
}
768+
// expected-error @+2 {{type 'TF_521<T>' does not conform to protocol 'Differentiable'}}
769+
// expected-note @+1 {{do you want to add protocol stubs}}
770+
extension TF_521: Differentiable where T: Differentiable {
771+
// expected-note @+1 {{possibly intended match 'TF_521<T>.TangentVector' does not conform to 'AdditiveArithmetic'}}
772+
typealias TangentVector = TF_521
773+
typealias AllDifferentiableVariables = TF_521
774+
}
775+
extension TF_521 where T: Differentiable, T == T.TangentVector {
776+
static func _vjpInit(real: T, imaginary: T) -> (TF_521, (TF_521) -> (T, T)) {
777+
return (TF_521(real: real, imaginary: imaginary), { ($0.real, $0.imaginary) })
778+
}
779+
}
780+
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
781+
let _: @differentiable(Float, Float) -> TF_521<Float> = { r, i in
782+
TF_521(real: r, imaginary: i)
783+
}
754784

755785
// TF-296: Infer `@differentiable` wrt parameters to be to all parameters that conform to `Differentiable`.
756786

0 commit comments

Comments
 (0)