Skip to content

Commit 4a71e58

Browse files
committed
Address review feedback.
1 parent da69544 commit 4a71e58

File tree

3 files changed

+18
-18
lines changed

3 files changed

+18
-18
lines changed

lib/SIL/SILPrinter.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3168,7 +3168,7 @@ void SILDifferentiabilityWitness::print(
31683168
}
31693169
}
31703170
if (!requirements.empty()) {
3171-
OS << " [where ";
3171+
OS << "[where ";
31723172
auto subPrinter = PrintOptions::printSIL();
31733173
subPrinter.GenericEnv = origGenEnv;
31743174
interleave(requirements,
@@ -3189,12 +3189,12 @@ void SILDifferentiabilityWitness::print(
31893189
if (jvp) {
31903190
OS << " jvp: ";
31913191
printSILFunctionNameAndType(OS, jvp);
3192-
OS << "\n";
3192+
OS << '\n';
31933193
}
31943194
if (vjp) {
31953195
OS << " vjp: ";
31963196
printSILFunctionNameAndType(OS, vjp);
3197-
OS << "\n";
3197+
OS << '\n';
31983198
}
31993199
OS << "}\n\n";
32003200
}

lib/SILGen/SILGen.cpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -772,22 +772,22 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
772772
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
773773
if (auto *vjpDecl = diffAttr->getVJPFunction())
774774
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
775-
emitDifferentiabilityWitness(AFD, F, diffAttr->getParameterIndices(), jvp,
776-
vjp,
777-
diffAttr->getDerivativeGenericSignature());
775+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
776+
AutoDiffConfig config{diffAttr->getParameterIndices(), resultIndices,
777+
diffAttr->getDerivativeGenericSignature()};
778+
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp);
778779
}
779780
}
780781
F->verify();
781782
}
782783

783784
void SILGenModule::emitDifferentiabilityWitness(
784785
AbstractFunctionDecl *originalAFD, SILFunction *originalFunction,
785-
IndexSubset *parameterIndices, SILFunction *jvp, SILFunction *vjp,
786-
GenericSignature *derivativeGenSig) {
786+
const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp) {
787787
auto *origFnType = originalAFD->getInterfaceType()->castTo<AnyFunctionType>();
788788
auto origSilFnType = originalFunction->getLoweredFunctionType();
789789
auto *loweredParamIndices = autodiff::getLoweredParameterIndices(
790-
parameterIndices, origFnType);
790+
config.parameterIndices, origFnType);
791791
// NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has
792792
// parameters corresponding to captured variables. These parameters do not
793793
// appear in the type of `origFnType`.
@@ -813,7 +813,7 @@ void SILGenModule::emitDifferentiabilityWitness(
813813

814814
// Get or create differentiability witness.
815815
CanGenericSignature derivativeCanGenSig;
816-
if (derivativeGenSig)
816+
if (auto *derivativeGenSig = config.derivativeGenericSignature)
817817
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
818818
// TODO(TF-835): Use simpler derivative generic signature logic below when
819819
// type-checking no longer generates implicit `@differentiable` attributes.
@@ -830,13 +830,12 @@ void SILGenModule::emitDifferentiabilityWitness(
830830
derivativeCanGenSig = vjpCanGenSig;
831831
assert(derivativeCanGenSig == vjpCanGenSig);
832832
}
833-
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
834833
// Create new SIL differentiability witness.
835834
// Witness JVP and VJP are set below.
836835
auto *diffWitness = SILDifferentiabilityWitness::create(
837836
M, originalFunction->getLinkage(), originalFunction,
838-
loweredParamIndices, resultIndices, derivativeGenSig, /*jvp*/ nullptr,
839-
/*vjp*/ nullptr, /*isSerialized*/ true);
837+
loweredParamIndices, config.resultIndices, derivativeCanGenSig,
838+
/*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ true);
840839

841840
// Set derivative function in differentiability witness.
842841
auto setDerivativeInDifferentiabilityWitness =
@@ -852,8 +851,10 @@ void SILGenModule::emitDifferentiabilityWitness(
852851
derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk(
853852
originalFunction, indices, derivative, kind, reorderSelf);
854853
} else {
854+
// Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with
855+
// the AST-level parameter indices, not the SIL-level ones.
855856
auto *id = AutoDiffDerivativeFunctionIdentifier::get(
856-
kind, parameterIndices, getASTContext());
857+
kind, config.parameterIndices, getASTContext());
857858
derivativeThunk = getOrCreateAutoDiffThunk(
858859
SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative,
859860
expectedDerivativeType);

lib/SILGen/SILGen.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,12 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
320320

321321
// SWIFT_ENABLE_TENSORFLOW
322322
/// Emit the differentiability witness for the given original function
323-
/// declaration and SIL function, parameter indices, and JVP and VJP
323+
/// declaration and SIL function, autodiff configuration, and JVP and VJP
324324
/// functions (null if undefined).
325325
void emitDifferentiabilityWitness(AbstractFunctionDecl *originalAFD,
326326
SILFunction *originalFunction,
327-
IndexSubset *parameterIndices,
328-
SILFunction *jvp, SILFunction *vjp,
329-
GenericSignature *derivativeGenSig);
327+
const AutoDiffConfig &config,
328+
SILFunction *jvp, SILFunction *vjp);
330329
// SWIFT_ENABLE_TENSORFLOW END
331330

332331
/// Emit the lazy initializer function for a global pattern binding

0 commit comments

Comments
 (0)