Skip to content

[AutoDiff] Start linear_function canonicalization skeleton #33057

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 59 additions & 21 deletions include/swift/SILOptimizer/Differentiation/ADContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,22 @@ class ADContext {
llvm::SmallVector<DifferentiableFunctionInst *, 32>
differentiableFunctionInsts;

/// The worklist (stack) of `linear_function` instructions to be processed.
llvm::SmallVector<LinearFunctionInst *, 32> linearFunctionInsts;

/// The set of `differentiable_function` instructions that have been
/// processed. Used to avoid reprocessing invalidated instructions.
/// NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
/// `ADContext::processDifferentiableFunctionInst`, this field may be removed.
llvm::SmallPtrSet<DifferentiableFunctionInst *, 32>
processedDifferentiableFunctionInsts;

/// The set of `linear_function` instructions that have been processed. Used
/// to avoid reprocessing invalidated instructions.
/// NOTE(TF-784): if we use `CanonicalizeInstruction` subclass to replace
/// `ADContext::processLinearFunctionInst`, this field may be removed.
llvm::SmallPtrSet<LinearFunctionInst *, 32> processedLinearFunctionInsts;

/// Mapping from witnesses to invokers.
/// `SmallMapVector` is used for deterministic insertion order iteration.
llvm::SmallMapVector<SILDifferentiabilityWitness *, DifferentiationInvoker,
Expand Down Expand Up @@ -121,30 +130,19 @@ class ADContext {
SILPassManager &getPassManager() const { return passManager; }
Lowering::TypeConverter &getTypeConverter() { return module.Types; }

/// Get or create the synthesized file for the given `SILFunction`.
/// Used by `LinearMapInfo` for adding generated linear map struct and
/// branching trace enum declarations.
SynthesizedFileUnit &getOrCreateSynthesizedFile(SILFunction *original);

/// Returns true if the `differentiable_function` instruction worklist is
/// empty.
bool isDifferentiableFunctionInstsWorklistEmpty() const {
return differentiableFunctionInsts.empty();
llvm::SmallVectorImpl<DifferentiableFunctionInst *> &
getDifferentiableFunctionInstWorklist() {
return differentiableFunctionInsts;
}

/// Pops and returns a `differentiable_function` instruction from the
/// worklist. Returns nullptr if the worklist is empty.
DifferentiableFunctionInst *popDifferentiableFunctionInstFromWorklist() {
if (differentiableFunctionInsts.empty())
return nullptr;
return differentiableFunctionInsts.pop_back_val();
llvm::SmallVectorImpl<LinearFunctionInst *> &getLinearFunctionInstWorklist() {
return linearFunctionInsts;
}

/// Adds the given `differentiable_function` instruction to the worklist.
void
addDifferentiableFunctionInstToWorklist(DifferentiableFunctionInst *dfi) {
differentiableFunctionInsts.push_back(dfi);
}
/// Get or create the synthesized file for the given `SILFunction`.
/// Used by `LinearMapInfo` for adding generated linear map struct and
/// branching trace enum declarations.
SynthesizedFileUnit &getOrCreateSynthesizedFile(SILFunction *original);

/// Returns true if the given `differentiable_function` instruction has
/// already been processed.
Expand All @@ -159,6 +157,17 @@ class ADContext {
processedDifferentiableFunctionInsts.insert(dfi);
}

/// Reports whether `lfi` is a member of the set of already-processed
/// `linear_function` instructions.
bool isLinearFunctionInstProcessed(LinearFunctionInst *lfi) const {
  const bool alreadyProcessed = processedLinearFunctionInsts.count(lfi) != 0;
  return alreadyProcessed;
}

/// Marks the given `linear_function` instruction as processed by inserting it
/// into the processed set.
void markLinearFunctionInstAsProcessed(LinearFunctionInst *lfi) {
processedLinearFunctionInsts.insert(lfi);
}

const llvm::SmallMapVector<SILDifferentiabilityWitness *,
DifferentiationInvoker, 32> &
getInvokers() const {
Expand Down Expand Up @@ -204,12 +213,26 @@ class ADContext {
IndexSubset *resultIndices, SILValue original,
Optional<std::pair<SILValue, SILValue>> derivativeFunctions = None);

// Given an `differentiable_function` instruction, finds the corresponding
/// Creates a `linear_function` instruction using the given builder
/// and arguments. Erase the newly created instruction from the processed set,
/// if it exists - it may exist in the processed set if it has the same
/// pointer value as a previously processed and deleted instruction.
LinearFunctionInst *
createLinearFunction(SILBuilder &builder, SILLocation loc,
IndexSubset *parameterIndices, SILValue original,
Optional<SILValue> transposeFunction = None);

// Given a `differentiable_function` instruction, finds the corresponding
// differential operator used in the AST. If no differential operator is
// found, return nullptr.
DifferentiableFunctionExpr *
findDifferentialOperator(DifferentiableFunctionInst *inst);

// Given a `linear_function` instruction, finds the corresponding differential
// operator used in the AST. If no differential operator is found, return
// nullptr.
LinearFunctionExpr *findDifferentialOperator(LinearFunctionInst *inst);

template <typename... T, typename... U>
InFlightDiagnostic diagnose(SourceLoc loc, Diag<T...> diag,
U &&... args) const {
Expand Down Expand Up @@ -300,6 +323,21 @@ ADContext::emitNondifferentiabilityError(SourceLoc loc,
return diagnose(loc, diag, std::forward<U>(args)...);
}

// For `linear_function` instructions: if the `linear_function` instruction
// comes from a differential operator, emit an error on the expression and a
// note on the non-differentiable operation. Otherwise, emit both an error and
// note on the non-differentiable operation.
case DifferentiationInvoker::Kind::LinearFunctionInst: {
auto *inst = invoker.getLinearFunctionInst();
if (auto *expr = findDifferentialOperator(inst)) {
diagnose(expr->getLoc(), diag::autodiff_function_not_differentiable_error)
.highlight(expr->getSubExpr()->getSourceRange());
return diagnose(loc, diag, std::forward<U>(args)...);
}
diagnose(loc, diag::autodiff_expression_not_differentiable_error);
return diagnose(loc, diag, std::forward<U>(args)...);
}

// For differentiability witnesses: try to find a `@differentiable` or
// `@derivative` attribute. If an attribute is found, emit an error on it;
// otherwise, emit an error on the original function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ namespace swift {

class ApplyInst;
class DifferentiableFunctionInst;
class LinearFunctionInst;
class SILDifferentiabilityWitness;

namespace autodiff {
Expand All @@ -42,6 +43,10 @@ struct DifferentiationInvoker {
// expression).
DifferentiableFunctionInst,

// Invoked by a `linear_function` instruction, which may or may not
// be linked to a Swift AST node (e.g. a `LinearFunctionExpr` expression).
LinearFunctionInst,

// Invoked by the indirect application of differentiation. This case has an
// associated original `apply` instruction and
// `SILDifferentiabilityWitness`.
Expand All @@ -60,6 +65,10 @@ struct DifferentiationInvoker {
DifferentiableFunctionInst *diffFuncInst;
Value(DifferentiableFunctionInst *inst) : diffFuncInst(inst) {}

/// The instruction associated with the `LinearFunctionInst` case.
LinearFunctionInst *linearFuncInst;
Value(LinearFunctionInst *inst) : linearFuncInst(inst) {}

/// The parent `apply` instruction and the witness associated with the
/// `IndirectDifferentiation` case.
std::pair<ApplyInst *, SILDifferentiabilityWitness *>
Expand All @@ -79,6 +88,8 @@ struct DifferentiationInvoker {
public:
DifferentiationInvoker(DifferentiableFunctionInst *inst)
: kind(Kind::DifferentiableFunctionInst), value(inst) {}
/// Creates an invoker of the `LinearFunctionInst` kind that stores the given
/// `linear_function` instruction as its associated value.
DifferentiationInvoker(LinearFunctionInst *inst)
: kind(Kind::LinearFunctionInst), value(inst) {}
DifferentiationInvoker(ApplyInst *applyInst,
SILDifferentiabilityWitness *witness)
: kind(Kind::IndirectDifferentiation), value({applyInst, witness}) {}
Expand All @@ -92,6 +103,11 @@ struct DifferentiationInvoker {
return value.diffFuncInst;
}

/// Returns the `linear_function` instruction stored in this invoker.
/// Asserts that the invoker kind is `LinearFunctionInst`.
LinearFunctionInst *getLinearFunctionInst() const {
assert(kind == Kind::LinearFunctionInst);
return value.linearFuncInst;
}

std::pair<ApplyInst *, SILDifferentiabilityWitness *>
getIndirectDifferentiation() const {
assert(kind == Kind::IndirectDifferentiation);
Expand Down
14 changes: 14 additions & 0 deletions lib/SILOptimizer/Differentiation/ADContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,24 @@ DifferentiableFunctionInst *ADContext::createDifferentiableFunction(
return dfi;
}

LinearFunctionInst *ADContext::createLinearFunction(
    SILBuilder &builder, SILLocation loc, IndexSubset *parameterIndices,
    SILValue original, Optional<SILValue> transposeFunction) {
  auto *const created = builder.createLinearFunction(
      loc, parameterIndices, original, transposeFunction);
  // The new instruction may have the same pointer value as a previously
  // processed and deleted one; remove any such entry from the processed set.
  processedLinearFunctionInsts.erase(created);
  return created;
}

DifferentiableFunctionExpr *
ADContext::findDifferentialOperator(DifferentiableFunctionInst *inst) {
  // Query the instruction's location for an attached
  // `DifferentiableFunctionExpr` AST node.
  const auto &instLoc = inst->getLoc();
  return instLoc.getAsASTNode<DifferentiableFunctionExpr>();
}

LinearFunctionExpr *
ADContext::findDifferentialOperator(LinearFunctionInst *inst) {
  // Query the instruction's location for an attached `LinearFunctionExpr`
  // AST node.
  const auto &instLoc = inst->getLoc();
  return instLoc.getAsASTNode<LinearFunctionExpr>();
}

} // end namespace autodiff
} // end namespace swift
5 changes: 5 additions & 0 deletions lib/SILOptimizer/Differentiation/DifferentiationInvoker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ SourceLoc DifferentiationInvoker::getLocation() const {
switch (kind) {
case Kind::DifferentiableFunctionInst:
return getDifferentiableFunctionInst()->getLoc().getSourceLoc();
case Kind::LinearFunctionInst:
return getLinearFunctionInst()->getLoc().getSourceLoc();
case Kind::IndirectDifferentiation:
return getIndirectDifferentiation().first->getLoc().getSourceLoc();
case Kind::SILDifferentiabilityWitnessInvoker:
Expand All @@ -46,6 +48,9 @@ void DifferentiationInvoker::print(llvm::raw_ostream &os) const {
os << "differentiable_function_inst=(" << *getDifferentiableFunctionInst()
<< ")";
break;
case Kind::LinearFunctionInst:
os << "linear_function_inst=(" << *getLinearFunctionInst() << ")";
break;
case Kind::IndirectDifferentiation: {
auto indDiff = getIndirectDifferentiation();
os << "indirect_differentiation=(" << *std::get<0>(indDiff) << ')';
Expand Down
12 changes: 10 additions & 2 deletions lib/SILOptimizer/Differentiation/JVPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ class JVPCloner::Implementation final
builder, loc, indices.parameters, indices.results, origCallee);

// Record the `differentiable_function` instruction.
context.addDifferentiableFunctionInstToWorklist(diffFuncInst);
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);

auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
auto extractedJVP = builder.createDifferentiableFunctionExtract(
Expand Down Expand Up @@ -749,7 +749,15 @@ class JVPCloner::Implementation final
// instruction to the `differentiable_function` worklist.
TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi));
context.addDifferentiableFunctionInstToWorklist(newDFI);
context.getDifferentiableFunctionInstWorklist().push_back(newDFI);
}

void visitLinearFunctionInst(LinearFunctionInst *lfi) {
  // First let the base cloner copy the `linear_function` instruction from the
  // original function into the JVP.
  TypeSubstCloner::visitLinearFunctionInst(lfi);
  // Then record the cloned instruction on the `linear_function` worklist.
  auto &worklist = context.getLinearFunctionInstWorklist();
  worklist.push_back(cast<LinearFunctionInst>(getOpValue(lfi)));
}

//--------------------------------------------------------------------------//
Expand Down
12 changes: 10 additions & 2 deletions lib/SILOptimizer/Differentiation/VJPCloner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ class VJPCloner::Implementation final
getBuilder(), loc, indices.parameters, indices.results, origCallee);

// Record the `differentiable_function` instruction.
context.addDifferentiableFunctionInstToWorklist(diffFuncInst);
context.getDifferentiableFunctionInstWorklist().push_back(diffFuncInst);

auto borrowedADFunc = builder.emitBeginBorrowOperation(loc, diffFuncInst);
auto extractedVJP = getBuilder().createDifferentiableFunctionExtract(
Expand Down Expand Up @@ -623,7 +623,15 @@ class VJPCloner::Implementation final
// instruction to the `differentiable_function` worklist.
TypeSubstCloner::visitDifferentiableFunctionInst(dfi);
auto *newDFI = cast<DifferentiableFunctionInst>(getOpValue(dfi));
context.addDifferentiableFunctionInstToWorklist(newDFI);
context.getDifferentiableFunctionInstWorklist().push_back(newDFI);
}

void visitLinearFunctionInst(LinearFunctionInst *lfi) {
  // First let the base cloner copy the `linear_function` instruction from the
  // original function into the VJP.
  TypeSubstCloner::visitLinearFunctionInst(lfi);
  // Then record the cloned instruction on the `linear_function` worklist.
  auto &worklist = context.getLinearFunctionInstWorklist();
  worklist.push_back(cast<LinearFunctionInst>(getOpValue(lfi)));
}
};

Expand Down
Loading