@@ -3636,9 +3636,12 @@ class JVPEmitter final
3636
3636
directResults.append (origResults.begin (), origResults.end ());
3637
3637
auto diffType = jvp->mapTypeIntoContext (
3638
3638
differential->getLoweredFunctionType ()
3639
- ->getWithRepresentation (SILFunctionTypeRepresentation::Thick))->getCanonicalType ();
3639
+ ->getWithRepresentation (SILFunctionTypeRepresentation::Thick))
3640
+ ->getCanonicalType ();
3640
3641
3641
- directResults.push_back (SILUndef::get (SILType::getPrimitiveObjectType (diffType), *differential));
3642
+ directResults.push_back (SILUndef::get (
3643
+ SILType::getPrimitiveObjectType (diffType),
3644
+ *differential));
3642
3645
builder.createReturn (
3643
3646
ri->getLoc (), joinElements (directResults, builder, loc));
3644
3647
}
@@ -6395,7 +6398,7 @@ static SILFunction *createEmptyVJP(
6395
6398
6396
6399
static SILFunction *createEmptyJVP (
6397
6400
ADContext &context, SILFunction *original, SILDifferentiableAttr *attr,
6398
- bool isExported) {
6401
+ bool isExported, bool vjpGenerated ) {
6399
6402
LLVM_DEBUG ({
6400
6403
auto &s = getADDebugStream ();
6401
6404
s << " Creating JVP:\n\t " ;
@@ -6475,23 +6478,6 @@ bool ADContext::processDifferentiableAttribute(
6475
6478
invoker.getKind () ==
6476
6479
DifferentiationInvoker::Kind::SILDifferentiableAttribute;
6477
6480
6478
- // If the JVP doesn't exist, need to synthesize it.
6479
- if (!jvp) {
6480
- // Diagnose:
6481
- // - Functions with no return.
6482
- // - Functions with unsupported control flow.
6483
- if (diagnoseNoReturn (*this , original, invoker) ||
6484
- diagnoseUnsupportedControlFlow (*this , original, invoker))
6485
- return true ;
6486
-
6487
- jvp = createEmptyJVP (*this , original, attr, isAssocFnExported);
6488
- getGeneratedFunctions ().push_back (jvp);
6489
- JVPEmitter emitter (*this , original, attr, jvp, invoker);
6490
- if (emitter.run ()) {
6491
- return true ;
6492
- }
6493
- }
6494
-
6495
6481
// Try to look up VJP only if attribute specifies VJP name or if original
6496
6482
// function is an external declaration. If VJP function cannot be found,
6497
6483
// create an external VJP reference.
@@ -6514,17 +6500,36 @@ bool ADContext::processDifferentiableAttribute(
6514
6500
}
6515
6501
6516
6502
// If the JVP doesn't exist, need to synthesize it.
6503
+ auto vjpGenerated = false ;
6517
6504
if (!vjp) {
6518
6505
// Diagnose:
6519
6506
// - Functions with no return.
6520
6507
// - Functions with unsupported control flow.
6521
6508
if (diagnoseNoReturn (*this , original, invoker) ||
6522
6509
diagnoseUnsupportedControlFlow (*this , original, invoker))
6523
6510
return true ;
6524
-
6511
+
6512
+ vjpGenerated = true ;
6525
6513
vjp = createEmptyVJP (*this , original, attr, isAssocFnExported);
6526
6514
getGeneratedFunctions ().push_back (vjp);
6527
6515
VJPEmitter emitter (*this , original, attr, vjp, invoker);
6516
+ if (emitter.run ()) {
6517
+ return true ;
6518
+ }
6519
+ }
6520
+
6521
+ // If the JVP doesn't exist, need to synthesize it.
6522
+ if (!jvp) {
6523
+ // Diagnose:
6524
+ // - Functions with no return.
6525
+ // - Functions with unsupported control flow.
6526
+ if (diagnoseNoReturn (*this , original, invoker) ||
6527
+ diagnoseUnsupportedControlFlow (*this , original, invoker))
6528
+ return true ;
6529
+
6530
+ jvp = createEmptyJVP (*this , original, attr, isAssocFnExported, vjpGenerated);
6531
+ getGeneratedFunctions ().push_back (jvp);
6532
+ JVPEmitter emitter (*this , original, attr, jvp, invoker);
6528
6533
return emitter.run ();
6529
6534
}
6530
6535
0 commit comments