Skip to content

Commit 31c81c6

Browse files
authored
[AutoDiff] Fix stdlib @derivative(of: subscript) crash. (#28935)
Fix `@derivative(of: subscript)` crash for SIMD{n} types. Simplify `findAbstractFunctionDecl`. If `lookupContext` is a type context, use its `self` type for member lookup. Resolves TF-1090. Exposes TF-1094: - During `.swiftinterface` compilation, unserialized `@derivative` function is not registered for serialized `@differentiable` function. - Workaround: mark `SIMD{n}._vjpSubscript(_:)` as `@inlinable` (serialized).
1 parent 7ffb598 commit 31c81c6

File tree

3 files changed

+45
-12
lines changed

3 files changed

+45
-12
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3195,20 +3195,15 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
31953195

31963196
// Perform lookup.
31973197
LookupResult results;
3198+
// If `baseType` is not null but `lookupContext` is a type context, set
3199+
// `baseType` to the `self` type of `lookupContext` to perform member lookup.
3200+
if (!baseType && lookupContext->isTypeContext())
3201+
baseType = lookupContext->getSelfTypeInContext();
31983202
if (baseType) {
31993203
results = TypeChecker::lookupMember(lookupContext, baseType, funcName);
32003204
} else {
32013205
results = TypeChecker::lookupUnqualified(lookupContext, funcName,
32023206
funcNameLoc, lookupOptions);
3203-
3204-
// If looking up an operator within a type context, look specifically within
3205-
// the type context.
3206-
// This tries to resolve unqualified operators, like `+`.
3207-
if (funcName.isOperator() && lookupContext->isTypeContext()) {
3208-
if (auto tmp = TypeChecker::lookupMember(
3209-
lookupContext, lookupContext->getSelfTypeInContext(), funcName))
3210-
results = tmp;
3211-
}
32123207
}
32133208

32143209
// Initialize error flags.

stdlib/public/core/SIMDVectorTypes.swift.gyb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ public struct SIMD${n}<Scalar>: SIMD where Scalar: SIMDScalar {
4646

4747
/// Accesses the scalar at the specified position.
4848
// SWIFT_ENABLE_TENSORFLOW
49-
@differentiable(vjp: _vjpSubscript
49+
@differentiable(
5050
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
5151
Scalar.TangentVector : BinaryFloatingPoint)
5252
public subscript(index: Int) -> Scalar {
@@ -206,9 +206,11 @@ extension SIMD${n} : EuclideanDifferentiable
206206
extension SIMD${n}
207207
where Scalar : EuclideanDifferentiable & BinaryFloatingPoint,
208208
Scalar.TangentVector : BinaryFloatingPoint {
209-
@usableFromInline
209+
// NOTE(TF-1094): serialized `@derivative` for `.swiftinterface` compilation.
210+
@inlinable
211+
@derivative(of: subscript(_:))
210212
internal func _vjpSubscript(index: Int)
211-
-> (Scalar, (Scalar.TangentVector) -> TangentVector) {
213+
-> (value: Scalar, pullback: (Scalar.TangentVector) -> TangentVector) {
212214
return (self[index], { v in
213215
var zeros = Self.zero
214216
zeros[index] = Scalar(v)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: %empty-directory(%t)
2+
// RUN: %target-swift-frontend -typecheck -emit-module-interface-path - %s -swift-version 5 -enable-library-evolution -module-name Module > %t/Module.swiftinterface
3+
// RUN: %target-swift-frontend -emit-silgen %t/Module.swiftinterface | %FileCheck %s --check-prefix=CHECK-SILGEN
4+
// RUN: not %target-swift-frontend -compile-module-from-interface %t/Module.swiftinterface -o %t/Module.swiftmodule 2>&1 | %FileCheck %s --check-prefix=CHECK-COMPILE
5+
6+
// TF-1094: Derivative registration fails for `.swiftinterface` compilation when
7+
// original `@differentiable` function is serialized but `@derivative` function
8+
// is unserialized.
9+
10+
@inlinable // serialized
11+
@differentiable
12+
@_silgen_name("foo")
13+
public func foo(_ x: Float) -> Float {
14+
fatalError()
15+
}
16+
17+
@usableFromInline // not serialized
18+
@derivative(of: foo)
19+
@_silgen_name("vjp_foo")
20+
func vjpFoo(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
21+
return (x, { $0 })
22+
}
23+
24+
// Missing differentiability witness VJP entry because `func vjpFoo` is bodiless
25+
// in the `.swiftinterface` file and not lowered to a SIL function.
26+
27+
// CHECK-SILGEN-LABEL: // differentiability witness for foo
28+
// CHECK-SILGEN-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] @foo : $@convention(thin) (Float) -> Float {
29+
// CHECK-SILGEN-NEXT: }
30+
31+
// CHECK-SILGEN-LABEL: sil [serialized] [ossa] @foo
32+
// CHECK-SILGEN-NOT: sil {{.*}} @vjp_foo
33+
34+
// CHECK-COMPILE: Module.swiftinterface:5:2: error: function is not differentiable
35+
// CHECK-COMPILE: Module.swiftinterface:7:24: note: when differentiating this function definition
36+
// CHECK-COMPILE: Module.swiftinterface:9:1: note: missing return for differentiation

0 commit comments

Comments
 (0)