Skip to content

Commit 5ccab9f

Browse files
committed
WIP
1 parent ee8818c commit 5ccab9f

File tree

1 file changed

+26
-21
lines changed

1 file changed

+26
-21
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3636,9 +3636,12 @@ class JVPEmitter final
36363636
directResults.append(origResults.begin(), origResults.end());
36373637
auto diffType = jvp->mapTypeIntoContext(
36383638
differential->getLoweredFunctionType()
3639-
->getWithRepresentation(SILFunctionTypeRepresentation::Thick))->getCanonicalType();
3639+
->getWithRepresentation(SILFunctionTypeRepresentation::Thick))
3640+
->getCanonicalType();
36403641

3641-
directResults.push_back(SILUndef::get(SILType::getPrimitiveObjectType(diffType), *differential));
3642+
directResults.push_back(SILUndef::get(
3643+
SILType::getPrimitiveObjectType(diffType),
3644+
*differential));
36423645
builder.createReturn(
36433646
ri->getLoc(), joinElements(directResults, builder, loc));
36443647
}
@@ -6395,7 +6398,7 @@ static SILFunction *createEmptyVJP(
63956398

63966399
static SILFunction *createEmptyJVP(
63976400
ADContext &context, SILFunction *original, SILDifferentiableAttr *attr,
6398-
bool isExported) {
6401+
bool isExported, bool vjpGenerated) {
63996402
LLVM_DEBUG({
64006403
auto &s = getADDebugStream();
64016404
s << "Creating JVP:\n\t";
@@ -6475,23 +6478,6 @@ bool ADContext::processDifferentiableAttribute(
64756478
invoker.getKind() ==
64766479
DifferentiationInvoker::Kind::SILDifferentiableAttribute;
64776480

6478-
// If the JVP doesn't exist, need to synthesize it.
6479-
if (!jvp) {
6480-
// Diagnose:
6481-
// - Functions with no return.
6482-
// - Functions with unsupported control flow.
6483-
if (diagnoseNoReturn(*this, original, invoker) ||
6484-
diagnoseUnsupportedControlFlow(*this, original, invoker))
6485-
return true;
6486-
6487-
jvp = createEmptyJVP(*this, original, attr, isAssocFnExported);
6488-
getGeneratedFunctions().push_back(jvp);
6489-
JVPEmitter emitter(*this, original, attr, jvp, invoker);
6490-
if (emitter.run()) {
6491-
return true;
6492-
}
6493-
}
6494-
64956481
// Try to look up VJP only if attribute specifies VJP name or if original
64966482
// function is an external declaration. If VJP function cannot be found,
64976483
// create an external VJP reference.
@@ -6514,17 +6500,36 @@ bool ADContext::processDifferentiableAttribute(
65146500
}
65156501

65166502
// If the JVP doesn't exist, need to synthesize it.
6503+
auto vjpGenerated = false;
65176504
if (!vjp) {
65186505
// Diagnose:
65196506
// - Functions with no return.
65206507
// - Functions with unsupported control flow.
65216508
if (diagnoseNoReturn(*this, original, invoker) ||
65226509
diagnoseUnsupportedControlFlow(*this, original, invoker))
65236510
return true;
6524-
6511+
6512+
vjpGenerated = true;
65256513
vjp = createEmptyVJP(*this, original, attr, isAssocFnExported);
65266514
getGeneratedFunctions().push_back(vjp);
65276515
VJPEmitter emitter(*this, original, attr, vjp, invoker);
6516+
if (emitter.run()) {
6517+
return true;
6518+
}
6519+
}
6520+
6521+
// If the JVP doesn't exist, need to synthesize it.
6522+
if (!jvp) {
6523+
// Diagnose:
6524+
// - Functions with no return.
6525+
// - Functions with unsupported control flow.
6526+
if (diagnoseNoReturn(*this, original, invoker) ||
6527+
diagnoseUnsupportedControlFlow(*this, original, invoker))
6528+
return true;
6529+
6530+
jvp = createEmptyJVP(*this, original, attr, isAssocFnExported, vjpGenerated);
6531+
getGeneratedFunctions().push_back(jvp);
6532+
JVPEmitter emitter(*this, original, attr, jvp, invoker);
65286533
return emitter.run();
65296534
}
65306535

0 commit comments

Comments
 (0)