-
Notifications
You must be signed in to change notification settings - Fork 10.5k
differentiate struct_extract instructions using VJP #21567
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -472,9 +472,9 @@ class PrimalInfo { | |
/// corresponding tape of its type. | ||
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap; | ||
|
||
/// Mapping from `apply` instructions in the original function to the | ||
/// corresponding pullback decl in the primal struct. | ||
DenseMap<ApplyInst *, VarDecl *> pullbackValueMap; | ||
/// Mapping from `apply` and `struct_extract` instructions in the original | ||
/// function to the corresponding pullback decl in the primal struct. | ||
DenseMap<SILInstruction *, VarDecl *> pullbackValueMap; | ||
|
||
/// Mapping from types of control-dependent nested primal values to district | ||
/// tapes. | ||
|
@@ -573,7 +573,7 @@ class PrimalInfo { | |
} | ||
|
||
/// Add a pullback to the primal value struct. | ||
VarDecl *addPullbackDecl(ApplyInst *inst, Type pullbackType) { | ||
VarDecl *addPullbackDecl(SILInstruction *inst, Type pullbackType) { | ||
// Decls must have AST types (not `SILFunctionType`), so we convert the | ||
// `SILFunctionType` of the pullback to a `FunctionType` with the same | ||
// parameters and results. | ||
|
@@ -605,9 +605,9 @@ class PrimalInfo { | |
: lookup->getSecond(); | ||
} | ||
|
||
/// Finds the pullback decl in the primal value struct for an `apply` in the | ||
/// original function. | ||
VarDecl *lookUpPullbackDecl(ApplyInst *inst) { | ||
/// Finds the pullback decl in the primal value struct for an `apply` or | ||
/// `struct_extract` in the original function. | ||
VarDecl *lookUpPullbackDecl(SILInstruction *inst) { | ||
auto lookup = pullbackValueMap.find(inst); | ||
return lookup == pullbackValueMap.end() ? nullptr | ||
: lookup->getSecond(); | ||
|
@@ -2227,6 +2227,79 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> { | |
SILClonerWithScopes::visitReleaseValueInst(rvi); | ||
} | ||
|
||
void visitStructExtractInst(StructExtractInst *sei) { | ||
// Special handling logic only applies when the `struct_extract` is active. | ||
// If not, just do standard cloning. | ||
if (!activityInfo.isActive(sei, synthesis.indices)) { | ||
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n'); | ||
SILClonerWithScopes::visitStructExtractInst(sei); | ||
return; | ||
} | ||
|
||
// This instruction is active. Replace it with a call to the corresponding | ||
// getter's VJP. | ||
|
||
// Find the corresponding getter and its VJP. | ||
auto *getterDecl = sei->getField()->getGetter(); | ||
assert(getterDecl); | ||
auto *getterFn = getContext().getModule().lookUpFunction( | ||
SILDeclRef(getterDecl, SILDeclRef::Kind::Func)); | ||
if (!getterFn) { | ||
getContext().emitNondifferentiabilityError( | ||
sei, synthesis.task, diag::autodiff_property_not_differentiable); | ||
errorOccurred = true; | ||
return; | ||
} | ||
auto getterDiffAttrs = getterFn->getDifferentiableAttrs(); | ||
if (getterDiffAttrs.size() < 1) { | ||
getContext().emitNondifferentiabilityError( | ||
sei, synthesis.task, diag::autodiff_property_not_differentiable); | ||
errorOccurred = true; | ||
return; | ||
} | ||
auto *getterDiffAttr = getterDiffAttrs[0]; | ||
if (!getterDiffAttr->hasVJP()) { | ||
getContext().emitNondifferentiabilityError( | ||
sei, synthesis.task, diag::autodiff_property_not_differentiable); | ||
errorOccurred = true; | ||
return; | ||
} | ||
assert(getterDiffAttr->getIndices() == | ||
SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0})); | ||
auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(), | ||
getContext().getModule()); | ||
|
||
// Reference and apply the VJP. | ||
auto loc = sei->getLoc(); | ||
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP); | ||
auto *getterVJPApply = getBuilder().createApply( | ||
loc, getterVJPRef, /*substitutionMap*/ {}, | ||
/*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false); | ||
|
||
// Get the VJP results (original results and pullback). | ||
SmallVector<SILValue, 8> vjpDirectResults; | ||
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults); | ||
ArrayRef<SILValue> originalDirectResults = | ||
ArrayRef<SILValue>(vjpDirectResults).drop_back(1); | ||
SILValue originalDirectResult = joinElements(originalDirectResults, | ||
getBuilder(), | ||
getterVJPApply->getLoc()); | ||
SILValue pullback = vjpDirectResults.back(); | ||
|
||
// Store the original result to the value map. | ||
mapValue(sei, originalDirectResult); | ||
|
||
// Checkpoint the original results. | ||
getPrimalInfo().addStaticPrimalValueDecl(sei); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reminder to self: must add a retain here
|
||
getBuilder().createRetainValue(loc, originalDirectResult, | ||
getBuilder().getDefaultAtomicity()); | ||
staticPrimalValues.push_back(originalDirectResult); | ||
|
||
// Checkpoint the pullback. | ||
getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType()); | ||
staticPrimalValues.push_back(pullback); | ||
} | ||
|
||
void visitApplyInst(ApplyInst *ai) { | ||
if (DifferentiationUseVJP) | ||
visitApplyInstWithVJP(ai); | ||
|
@@ -3522,33 +3595,36 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> { | |
} | ||
} | ||
|
||
/// Handle `struct_extract` instruction. | ||
/// y = struct_extract <key>, x | ||
/// adj[x] = struct (0, ..., key: adj[y], ..., 0) | ||
void visitStructExtractInst(StructExtractInst *sei) { | ||
auto *structDecl = sei->getStructDecl(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When we have derived conformances to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, this gives me an idea. The original I think that the original I will try out this approach. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds reasonable for internal structs. Public structs need property VJPs in any case because of resilience requirements. |
||
auto av = getAdjointValue(sei); | ||
switch (av.getKind()) { | ||
case AdjointValue::Kind::Zero: | ||
addAdjointValue(sei->getOperand(), | ||
AdjointValue::getZero(sei->getOperand()->getType())); | ||
break; | ||
case AdjointValue::Kind::Materialized: | ||
case AdjointValue::Kind::Aggregate: { | ||
SmallVector<AdjointValue, 8> eltVals; | ||
for (auto *field : structDecl->getStoredProperties()) { | ||
if (field == sei->getField()) | ||
eltVals.push_back(av); | ||
else | ||
eltVals.push_back(AdjointValue::getZero( | ||
SILType::getPrimitiveObjectType( | ||
field->getType()->getCanonicalType()))); | ||
} | ||
addAdjointValue(sei->getOperand(), | ||
AdjointValue::getAggregate(sei->getOperand()->getType(), | ||
eltVals, allocator)); | ||
} | ||
// Replace a `struct_extract` with a call to its pullback. | ||
auto loc = remapLocation(sei->getLoc()); | ||
|
||
// Get the pullback. | ||
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei); | ||
if (!pullbackField) { | ||
// Inactive `struct_extract` instructions don't need to be cloned into the | ||
// adjoint. | ||
assert(!activityInfo.isActive(sei, synthesis.indices)); | ||
return; | ||
} | ||
SILValue pullback = builder.createStructExtract(loc, | ||
primalValueAggregateInAdj, | ||
pullbackField); | ||
|
||
// Construct the pullback arguments. | ||
SmallVector<SILValue, 8> args; | ||
auto seed = getAdjointValue(sei); | ||
assert(seed.getType().isObject()); | ||
args.push_back(materializeAdjointDirect(seed, loc)); | ||
|
||
// Call the pullback. | ||
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), | ||
args, /*isNonThrowing*/ false); | ||
assert(!pullbackCall->hasIndirectResults()); | ||
|
||
// Set adjoint for the `struct_extract` operand. | ||
addAdjointValue(sei->getOperand(), | ||
AdjointValue::getMaterialized(pullbackCall)); | ||
} | ||
|
||
/// Handle `tuple` instruction. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When will a getter not exist?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One example is the case that I added to
test/AutoDiff/autodiff_diagnostics.swift
.I don't know why the getter doesn't exist in that case, and if I put the same code in a file an compile it with a plain call to
swiftc
, the getter does exist and the code ends up triggering the case where the getter exists but doesn't have a VJP.I haven't investigated any farther than that, so I don't really know what's going on but the logic seems to handle all the cases I can think of correctly.