@@ -3257,19 +3257,23 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
3257
3257
3258
3258
auto withoutDifferentiablePattern = [](AbstractionPattern pattern)
3259
3259
-> AbstractionPattern {
3260
- auto patternType = cast<AnyFunctionType>(pattern.getType ());
3260
+ auto patternType = pattern.getAs <AnyFunctionType>();
3261
+ // If pattern does not store an `AnyFunctionType`, return original pattern.
3262
+ // This logic handles opaque abstraction patterns.
3263
+ if (!patternType)
3264
+ return pattern;
3261
3265
pattern.rewriteType (
3262
3266
pattern.getGenericSignature (),
3263
3267
patternType->getWithoutDifferentiability ()->getCanonicalType ());
3264
3268
return pattern;
3265
3269
};
3266
3270
3271
+ auto inputOrigTypeNotDiff = withoutDifferentiablePattern (inputOrigType);
3267
3272
CanAnyFunctionType inputSubstTypeNotDiff (
3268
3273
inputSubstType->getWithoutDifferentiability ());
3269
- auto inputOrigTypeNotDiff = withoutDifferentiablePattern (inputOrigType );
3274
+ auto outputOrigTypeNotDiff = withoutDifferentiablePattern (outputOrigType );
3270
3275
CanAnyFunctionType outputSubstTypeNotDiff (
3271
3276
outputSubstType->getWithoutDifferentiability ());
3272
- auto outputOrigTypeNotDiff = withoutDifferentiablePattern (outputOrigType);
3273
3277
auto &expectedTLNotDiff = SGF.getTypeLowering (outputOrigTypeNotDiff,
3274
3278
outputSubstTypeNotDiff);
3275
3279
// `autodiff_function_extract` is consuming; copy `fn` before passing as
@@ -3301,18 +3305,21 @@ static ManagedValue createAutoDiffThunk(SILGenFunction &SGF,
3301
3305
auto getAssocFnPattern =
3302
3306
[&](AbstractionPattern pattern, AutoDiffAssociatedFunctionKind kind)
3303
3307
-> AbstractionPattern {
3308
+ // If pattern does not store an `AnyFunctionType`, return original
3309
+ // pattern. This logic handles opaque abstraction patterns.
3310
+ auto patternType = pattern.getAs <AnyFunctionType>();
3311
+ if (!patternType)
3312
+ return pattern;
3304
3313
return AbstractionPattern (
3305
- pattern.getGenericSignature (),
3306
- getAssocFnTy (cast<AnyFunctionType>(pattern.getType ()), kind));
3314
+ pattern.getGenericSignature (), getAssocFnTy (patternType, kind));
3307
3315
};
3308
3316
auto createAssocFnThunk = [&](AutoDiffAssociatedFunctionKind kind)
3309
3317
-> ManagedValue {
3318
+ auto assocFnInputOrigType = getAssocFnPattern (inputOrigTypeNotDiff, kind);
3310
3319
auto assocFnInputSubstType = getAssocFnTy (inputSubstTypeNotDiff, kind);
3311
- auto assocFnInputOrigType = getAssocFnPattern (inputOrigTypeNotDiff,
3312
- kind);
3313
- auto assocFnOutputSubstType = getAssocFnTy (outputSubstTypeNotDiff, kind);
3314
3320
auto assocFnOutputOrigType = getAssocFnPattern (outputOrigTypeNotDiff,
3315
3321
kind);
3322
+ auto assocFnOutputSubstType = getAssocFnTy (outputSubstTypeNotDiff, kind);
3316
3323
auto &assocFnExpectedTL = SGF.getTypeLowering (assocFnOutputOrigType,
3317
3324
assocFnOutputSubstType);
3318
3325
// `autodiff_function_extract` is consuming; copy `fn` before passing as
0 commit comments