Skip to content

Commit b39ea60

Browse files
authored
[AutoDiff] Partial fix for @differentiable function thunking crash. (#25803)
Partial fix for `@differentiable` function thunk with output opaque abstraction pattern. Avoid `AbstractionPattern` methods that crash for opaque patterns. Small step towards fixing TF-123. Reproducers no longer crash, SIL verification fails instead. Robust fix requires more work - see TF-123 for more details.
1 parent bd14b05 commit b39ea60

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

lib/SILGen/SILGenPoly.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3257,19 +3257,23 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
32573257

32583258
auto withoutDifferentiablePattern = [](AbstractionPattern pattern)
32593259
-> AbstractionPattern {
3260-
auto patternType = cast<AnyFunctionType>(pattern.getType());
3260+
auto patternType = pattern.getAs<AnyFunctionType>();
3261+
// If pattern does not store an `AnyFunctionType`, return original pattern.
3262+
// This logic handles opaque abstraction patterns.
3263+
if (!patternType)
3264+
return pattern;
32613265
pattern.rewriteType(
32623266
pattern.getGenericSignature(),
32633267
patternType->getWithoutDifferentiability()->getCanonicalType());
32643268
return pattern;
32653269
};
32663270

3271+
auto inputOrigTypeNotDiff = withoutDifferentiablePattern(inputOrigType);
32673272
CanAnyFunctionType inputSubstTypeNotDiff(
32683273
inputSubstType->getWithoutDifferentiability());
3269-
auto inputOrigTypeNotDiff = withoutDifferentiablePattern(inputOrigType);
3274+
auto outputOrigTypeNotDiff = withoutDifferentiablePattern(outputOrigType);
32703275
CanAnyFunctionType outputSubstTypeNotDiff(
32713276
outputSubstType->getWithoutDifferentiability());
3272-
auto outputOrigTypeNotDiff = withoutDifferentiablePattern(outputOrigType);
32733277
auto &expectedTLNotDiff = SGF.getTypeLowering(outputOrigTypeNotDiff,
32743278
outputSubstTypeNotDiff);
32753279
// `autodiff_function_extract` is consuming; copy `fn` before passing as
@@ -3301,18 +3305,21 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
33013305
auto getAssocFnPattern =
33023306
[&](AbstractionPattern pattern, AutoDiffAssociatedFunctionKind kind)
33033307
-> AbstractionPattern {
3308+
// If pattern does not store an `AnyFunctionType`, return original
3309+
// pattern. This logic handles opaque abstraction patterns.
3310+
auto patternType = pattern.getAs<AnyFunctionType>();
3311+
if (!patternType)
3312+
return pattern;
33043313
return AbstractionPattern(
3305-
pattern.getGenericSignature(),
3306-
getAssocFnTy(cast<AnyFunctionType>(pattern.getType()), kind));
3314+
pattern.getGenericSignature(), getAssocFnTy(patternType, kind));
33073315
};
33083316
auto createAssocFnThunk = [&](AutoDiffAssociatedFunctionKind kind)
33093317
-> ManagedValue {
3318+
auto assocFnInputOrigType = getAssocFnPattern(inputOrigTypeNotDiff, kind);
33103319
auto assocFnInputSubstType = getAssocFnTy(inputSubstTypeNotDiff, kind);
3311-
auto assocFnInputOrigType = getAssocFnPattern(inputOrigTypeNotDiff,
3312-
kind);
3313-
auto assocFnOutputSubstType = getAssocFnTy(outputSubstTypeNotDiff, kind);
33143320
auto assocFnOutputOrigType = getAssocFnPattern(outputOrigTypeNotDiff,
33153321
kind);
3322+
auto assocFnOutputSubstType = getAssocFnTy(outputSubstTypeNotDiff, kind);
33163323
auto &assocFnExpectedTL = SGF.getTypeLowering(assocFnOutputOrigType,
33173324
assocFnOutputSubstType);
33183325
// `autodiff_function_extract` is consuming; copy `fn` before passing as
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: not --crash %target-swift-frontend -emit-sil %s 2>&1
2+
// REQUIRES: asserts
3+
4+
// NOTE: Remove `not --crash` from RUN line when TF-123 is fixed.
5+
6+
// FIXME(TF-123): `@differentiable` function thunking with opaque
7+
// abstraction patterns.
8+
func blackHole(_ x: Any) {}
9+
let f: @differentiable (Float) -> Float = { $0 }
10+
blackHole(f)
11+
12+
// FIXME(TF-123): `@differentiable` function thunking with opaque
13+
// abstraction patterns.
14+
struct TF_123 {
15+
var f: @differentiable (Float) -> Float
16+
}
17+
_ = \TF_123.f

0 commit comments

Comments
 (0)