@@ -3303,27 +3303,36 @@ class VJPEmitter final
3303
3303
/* differentiationOrder*/ 1 , functionSource);
3304
3304
}
3305
3305
3306
- // Check and diagnose non-differentiable arguments.
3307
- for (unsigned paramIndex : range (originalFnTy->getNumParameters ())) {
3308
- if (indices.isWrtParameter (paramIndex) &&
3309
- !originalFnTy->getParameters ()[paramIndex]
3310
- .getSILStorageType ()
3311
- .isDifferentiable (getModule ())) {
3312
- context.emitNondifferentiabilityError (
3313
- original, invoker, diag::autodiff_nondifferentiable_argument);
3314
- errorOccurred = true ;
3315
- return ;
3316
- }
3317
- }
3318
- // Check and diagnose non-differentiable results.
3319
- if (!originalFnTy->getResults ()[indices.source ]
3320
- .getSILStorageType ()
3321
- .isDifferentiable (getModule ())) {
3322
- context.emitNondifferentiabilityError (
3323
- original, invoker, diag::autodiff_nondifferentiable_result);
3324
- errorOccurred = true ;
3306
+ // Check and diagnose non-differentiable original function type.
3307
+ auto diagnoseNondifferentiableOriginalFunctionType =
3308
+ [&](CanSILFunctionType origFnTy) {
3309
+ // Check and diagnose non-differentiable arguments.
3310
+ for (unsigned paramIndex : range (originalFnTy->getNumParameters ())) {
3311
+ if (indices.isWrtParameter (paramIndex) &&
3312
+ !originalFnTy->getParameters ()[paramIndex]
3313
+ .getSILStorageType ()
3314
+ .isDifferentiable (getModule ())) {
3315
+ context.emitNondifferentiabilityError (
3316
+ ai->getArgumentsWithoutIndirectResults ()[paramIndex], invoker,
3317
+ diag::autodiff_nondifferentiable_argument);
3318
+ errorOccurred = true ;
3319
+ return true ;
3320
+ }
3321
+ }
3322
+ // Check and diagnose non-differentiable results.
3323
+ if (!originalFnTy->getResults ()[indices.source ]
3324
+ .getSILStorageType ()
3325
+ .isDifferentiable (getModule ())) {
3326
+ context.emitNondifferentiabilityError (
3327
+ original, invoker, diag::autodiff_nondifferentiable_result);
3328
+ errorOccurred = true ;
3329
+ return true ;
3330
+ }
3331
+ return false ;
3332
+ };
3333
+ if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy))
3325
3334
return ;
3326
- }
3335
+
3327
3336
// If VJP has not yet been found, emit an `autodiff_function` instruction
3328
3337
// on the remapped original function operand and `autodiff_function_extract`
3329
3338
// the VJP. The actual JVP/VJP functions will be populated in the
@@ -3354,6 +3363,10 @@ class VJPEmitter final
3354
3363
ai->getLoc (), original, substMap, {},
3355
3364
ParameterConvention::Direct_Guaranteed);
3356
3365
original = vjpPartialApply;
3366
+ originalFnTy = original->getType ().castTo <SILFunctionType>();
3367
+ // Diagnose if new original function type is non-differentiable.
3368
+ if (diagnoseNondifferentiableOriginalFunctionType (originalFnTy))
3369
+ return ;
3357
3370
}
3358
3371
3359
3372
auto *autoDiffFuncInst = context.createAutoDiffFunction (
@@ -3363,6 +3376,8 @@ class VJPEmitter final
3363
3376
3364
3377
// Record the `autodiff_function` instruction.
3365
3378
context.getAutoDiffFunctionInsts ().push_back (autoDiffFuncInst);
3379
+ // TODO(TF-689): Make `autodiff_function` store result indices and remove
3380
+ // `ADContext::resultIndices`.
3366
3381
context.getResultIndices ()[autoDiffFuncInst] =
3367
3382
activeResultIndices.front ();
3368
3383
0 commit comments