Skip to content

[AutoDiff upstream] Add differentiability_witness_function instruction. #29765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5666,6 +5666,42 @@ destination (if it returns with ``throw``).

The rules on generic substitutions are identical to those of ``apply``.

Differentiable Programming
~~~~~~~~~~~~~~~~~~~~~~~~~~

differentiability_witness_function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This instruction is very confusing to me (not sure if its just me though). differentiability_witness_function makes me think that this returns the witness. If this is looking up the associated differentiation function associated with a function would you be open to renaming this to something like differential?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name is similar to ’witness_method’, where it returns a method from a witness. ‘differentiability_witness_function’ returns a function from a differentiability witness. “Differential” is not the technically right name and can cause confusion with the “differential” in our API. The returned function is not a differential, but a JVP/VJP (derivative) function.

I’m entirely open to renaming this to ‘derivative_function’, but I just wanted to clarify that the current name falls in line with the name ‘witness_method’.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I thought Id looked through the rest and then came back and changed differential to derivative. Id be okay with derivative_function as well.

The thing is that function becomes a fuzzy term here. It fetches the function pointer from the witness, returning a function which implements the function which performs a derivative over a function. The way that I initially read it, it made me question whether it was getting a getter for the differentiability witness or an entry in the witness table.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

derivative_function is perhaps not wholly precise, because the instruction has a [transpose] option for returning a transpose function, in addition to [jvp] and [vjp] options for returning derivative functions. I'm open to derivative_function though.

Do we have consensus on derivative_function (or some other name) as an alternative name for differentiability_witness_function? cc @rxwei @marcrasi

If not, let's merge and defer renaming until later. I'll merge within a day if no one responds.

Copy link
Contributor

@rxwei rxwei Feb 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future each witness will be either a linear witness or non-linear witness, so it’s possible to define both a ‘derivative_function’ and a ‘transpose_function’.

That said, I don’t like the idea of dropping the word “witness” from this instruction. It makes it sounds like it’s differentiating or transposing something, whereas all it does is retrieve the pointer to a function in the witness.

In the future when we unify JVP and VJP into a single derivative, a differentiability witness will be equivalent to a function forward declaration, so calling it ‘differentiability_witness_function’ is quite clear. I’m supporting keeping the existing name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future when we unify JVP and VJP into a single derivative, a differentiability witness will be equivalent to a function forward declaration, so calling it ‘differentiability_witness_function’ is quite clear.

Interesting point. I'll go ahead and merge this patch now.

``````````````````````````````````
::

sil-instruction ::=
'differentiability_witness_function'
'[' sil-differentiability-witness-function-kind ']'
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
'[' 'results' sil-differentiability-witness-function-index-list ']'
generic-parameter-clause?
sil-function-name ':' sil-type

sil-differentiability-witness-function-kind ::= 'jvp' | 'vjp' | 'transpose'
sil-differentiability-witness-function-index-list ::= [0-9]+ (' ' [0-9]+)*

differentiability_witness_function [jvp] [parameters 0] [results 0] \
<T where T: Differentiable> @foo : $(T) -> T

Looks up a differentiability witness function (JVP, VJP, or transpose) for
a referenced function via SIL differentiability witnesses.

The differentiability witness function kind identifies the witness function to
look up: ``[jvp]``, ``[vjp]``, or ``[transpose]``.

The remaining components identify the SIL differentiability witness:

- Original function name.
- Parameter indices.
- Result indices.
- Witness generic parameter clause (optional). When parsing SIL, the parsed
witness generic parameter clause is combined with the original function's
generic signature to form the full witness generic signature.

Assertion configuration
~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
21 changes: 21 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ struct AutoDiffDerivativeFunctionKind {
}
};

/// The kind of a differentiability witness function.
struct DifferentiabilityWitnessFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1,
// The transpose function.
Transpose = 2
} rawValue;

DifferentiabilityWitnessFunctionKind() = default;
DifferentiabilityWitnessFunctionKind(innerty rawValue) : rawValue(rawValue) {}
explicit DifferentiabilityWitnessFunctionKind(unsigned rawValue)
: rawValue(static_cast<innerty>(rawValue)) {}
explicit DifferentiabilityWitnessFunctionKind(StringRef name);
operator innerty() const { return rawValue; }

Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};

/// Identifies an autodiff derivative function configuration:
/// - Parameter indices.
/// - Result indices.
Expand Down
7 changes: 7 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,13 @@ ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
"expected the index of a parameter to differentiate with respect to", ())
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
"expected the index of a result to differentiate from", ())
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,
"expected a differentiability witness kind, which can be one of '[jvp]', "
"'[vjp]', or '[transpose]'", ())
ERROR(sil_inst_autodiff_invalid_witness_generic_signature,PointsToFirstBadToken,
"expected witness_generic signature '%0' does not have same generic "
"parameters as original function generic signature '%1'",
(StringRef, StringRef))

//------------------------------------------------------------------------------
// MARK: Generics parsing diagnostics
Expand Down
14 changes: 14 additions & 0 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2157,6 +2157,20 @@ class SILBuilder {
SILValue emitThickToObjCMetatype(SILLocation Loc, SILValue Op, SILType Ty);
SILValue emitObjCToThickMetatype(SILLocation Loc, SILValue Op, SILType Ty);

//===--------------------------------------------------------------------===//
// Differentiable programming instructions
//===--------------------------------------------------------------------===//

/// Note: explicit function type may be specified only in lowered SIL.
DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction(
SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind,
SILDifferentiabilityWitness *Witness,
Optional<SILType> FunctionType = None) {
return insert(new (getModule()) DifferentiabilityWitnessFunctionInst(
getModule(), getSILDebugLocation(Loc), WitnessKind, Witness,
FunctionType));
}

//===--------------------------------------------------------------------===//
// Private Helper Methods
//===--------------------------------------------------------------------===//
Expand Down
10 changes: 10 additions & 0 deletions include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -2825,6 +2825,16 @@ void SILCloner<ImplClass>::visitKeyPathInst(KeyPathInst *Inst) {
opValues, getOpType(Inst->getType())));
}

template <typename ImplClass>
void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
recordClonedInstruction(Inst,
getBuilder().createDifferentiabilityWitnessFunction(
getOpLocation(Inst->getLoc()),
Inst->getWitnessKind(), Inst->getWitness()));
}

} // end namespace swift

#endif
36 changes: 36 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef SWIFT_SIL_INSTRUCTION_H
#define SWIFT_SIL_INSTRUCTION_H

#include "swift/AST/AutoDiff.h"
#include "swift/AST/Builtins.h"
#include "swift/AST/Decl.h"
#include "swift/AST/GenericSignature.h"
Expand Down Expand Up @@ -61,6 +62,7 @@ class SILBasicBlock;
class SILBuilder;
class SILDebugLocation;
class SILDebugScope;
class SILDifferentiabilityWitness;
class SILFunction;
class SILGlobalVariable;
class SILInstructionResultArray;
Expand Down Expand Up @@ -7931,6 +7933,40 @@ class TryApplyInst final
const GenericSpecializationInformation *SpecializationInfo);
};

class DifferentiabilityWitnessFunctionInst
: public InstructionBase<
SILInstructionKind::DifferentiabilityWitnessFunctionInst,
SingleValueInstruction> {
private:
friend SILBuilder;
/// The differentiability witness function kind.
DifferentiabilityWitnessFunctionKind witnessKind;
/// The referenced SIL differentiability witness.
SILDifferentiabilityWitness *witness;
/// Whether the instruction has an explicit function type.
bool hasExplicitFunctionType;

static SILType getDifferentiabilityWitnessType(
SILModule &module, DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness);

public:
/// Note: explicit function type may be specified only in lowered SIL.
DifferentiabilityWitnessFunctionInst(
SILModule &module, SILDebugLocation loc,
DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness, Optional<SILType> FunctionType);

DifferentiabilityWitnessFunctionKind getWitnessKind() const {
return witnessKind;
}
SILDifferentiabilityWitness *getWitness() const { return witness; }
bool getHasExplicitFunctionType() const { return hasExplicitFunctionType; }

ArrayRef<Operand> getAllOperands() const { return {}; }
MutableArrayRef<Operand> getAllOperands() { return {}; }
};

// This is defined out of line to work around the fact that this depends on
// PartialApplyInst being defined, but PartialApplyInst is a subclass of
// ApplyInstBase, so we can not place ApplyInstBase after it.
Expand Down
5 changes: 5 additions & 0 deletions include/swift/SIL/SILNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,11 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
SINGLE_VALUE_INST(InitBlockStorageHeaderInst, init_block_storage_header,
SingleValueInstruction, None, DoesNotRelease)

// Differentiable programming
SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst,
differentiability_witness_function,
SingleValueInstruction, None, DoesNotRelease)

// Key paths
// TODO: The only "side effect" is potentially retaining the returned key path
// object; is there a more specific effect?
Expand Down
25 changes: 24 additions & 1 deletion lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 Apple Inc. and the Swift project authors
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
Expand All @@ -12,11 +12,34 @@

#include "swift/AST/AutoDiff.h"
#include "swift/AST/ASTContext.h"
#include "swift/AST/Module.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"

using namespace swift;

DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
StringRef string) {
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP)
.Case("vjp", VJP)
.Case("transpose", Transpose);
assert(result && "Invalid string");
rawValue = *result;
}

Optional<AutoDiffDerivativeFunctionKind>
DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
switch (rawValue) {
case JVP:
return {AutoDiffDerivativeFunctionKind::JVP};
case VJP:
return {AutoDiffDerivativeFunctionKind::VJP};
case Transpose:
return None;
}
}

void AutoDiffConfig::print(llvm::raw_ostream &s) const {
s << "(parameters=";
parameterIndices->print(s);
Expand Down
30 changes: 30 additions & 0 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,9 @@ class IRGenSILFunction :

void visitKeyPathInst(KeyPathInst *I);

void visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *i);

#define LOADABLE_REF_STORAGE_HELPER(Name) \
void visitRefTo##Name##Inst(RefTo##Name##Inst *i); \
void visit##Name##ToRefInst(Name##ToRefInst *i); \
Expand Down Expand Up @@ -1815,6 +1818,33 @@ void IRGenSILFunction::visitSILBasicBlock(SILBasicBlock *BB) {
assert(Builder.hasPostTerminatorIP() && "SIL bb did not terminate block?!");
}

void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *i) {
llvm::Value *diffWitness =
IGM.getAddrOfDifferentiabilityWitness(i->getWitness());
unsigned offset = 0;
switch (i->getWitnessKind()) {
case DifferentiabilityWitnessFunctionKind::JVP:
offset = 0;
break;
case DifferentiabilityWitnessFunctionKind::VJP:
offset = 1;
break;
case DifferentiabilityWitnessFunctionKind::Transpose:
llvm_unreachable("Not yet implemented");
}

diffWitness = Builder.CreateStructGEP(diffWitness, offset);
diffWitness = Builder.CreateLoad(diffWitness, IGM.getPointerAlignment());

auto fnType = cast<SILFunctionType>(i->getType().getASTType());
Signature signature = IGM.getSignature(fnType);
diffWitness =
Builder.CreateBitCast(diffWitness, signature.getType()->getPointerTo());

setLoweredFunctionPointer(i, FunctionPointer(diffWitness, signature));
}

void IRGenSILFunction::visitFunctionRefBaseInst(FunctionRefBaseInst *i) {
auto fn = i->getInitiallyReferencedFunction();

Expand Down
41 changes: 41 additions & 0 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5023,6 +5023,47 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
blockType, subMap);
break;
}
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
// e.g. differentiability_witness_function
// [jvp] [parameters 0 1] [results 0] <T where T: Differentiable>
// @foo : <T> $(T) -> T
DifferentiabilityWitnessFunctionKind witnessKind;
StringRef witnessKindNames[3] = {"jvp", "vjp", "transpose"};
if (P.parseToken(
tok::l_square,
diag::
sil_inst_autodiff_expected_differentiability_witness_kind) ||
parseSILIdentifierSwitch(
witnessKind, witnessKindNames,
diag::
sil_inst_autodiff_expected_differentiability_witness_kind) ||
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
"differentiability witness function kind"))
return true;
SourceLoc keyStartLoc = P.Tok.getLoc();
auto configAndFn =
parseSILDifferentiabilityWitnessConfigAndFunction(P, *this, InstLoc);
if (!configAndFn)
return true;
auto config = configAndFn->first;
auto originalFn = configAndFn->second;
auto *witness = SILMod.lookUpDifferentiabilityWitness(
{originalFn->getName(), config});
if (!witness) {
P.diagnose(keyStartLoc, diag::sil_diff_witness_undefined);
return true;
}
// Parse an optional explicit function type.
Optional<SILType> functionType = None;
if (P.consumeIf(tok::kw_as)) {
functionType = SILType();
if (parseSILType(*functionType))
return true;
}
ResultVal = B.createDifferentiabilityWitnessFunction(
InstLoc, witnessKind, witness, functionType);
break;
}
}

return false;
Expand Down
1 change: 1 addition & 0 deletions lib/SIL/OperandOwnership.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ SHOULD_NEVER_VISIT_INST(AllocBox)
SHOULD_NEVER_VISIT_INST(AllocExistentialBox)
SHOULD_NEVER_VISIT_INST(AllocGlobal)
SHOULD_NEVER_VISIT_INST(AllocStack)
SHOULD_NEVER_VISIT_INST(DifferentiabilityWitnessFunction)
SHOULD_NEVER_VISIT_INST(FloatLiteral)
SHOULD_NEVER_VISIT_INST(FunctionRef)
SHOULD_NEVER_VISIT_INST(DynamicFunctionRef)
Expand Down
44 changes: 44 additions & 0 deletions lib/SIL/SILInstructions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,50 @@ TryApplyInst *TryApplyInst::create(
normalBB, errorBB, specializationInfo);
}

SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType(
SILModule &module, DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness) {
auto fnTy = witness->getOriginalFunction()->getLoweredFunctionType();
CanGenericSignature witnessCanGenSig;
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
witnessCanGenSig = witnessGenSig->getCanonicalSignature();
auto *parameterIndices = witness->getParameterIndices();
auto *resultIndices = witness->getResultIndices();
if (auto derivativeKind = witnessKind.getAsDerivativeFunctionKind()) {
bool isReabstractionThunk =
witness->getOriginalFunction()->isThunk() == IsReabstractionThunk;
auto diffFnTy = fnTy->getAutoDiffDerivativeFunctionType(
parameterIndices, *resultIndices->begin(), *derivativeKind,
module.Types, LookUpConformanceInModule(module.getSwiftModule()),
witnessCanGenSig, isReabstractionThunk);
return SILType::getPrimitiveObjectType(diffFnTy);
}
assert(witnessKind == DifferentiabilityWitnessFunctionKind::Transpose);
auto transposeFnTy = fnTy->getAutoDiffTransposeFunctionType(
parameterIndices, module.Types,
LookUpConformanceInModule(module.getSwiftModule()), witnessCanGenSig);
return SILType::getPrimitiveObjectType(transposeFnTy);
}

DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst(
SILModule &module, SILDebugLocation debugLoc,
DifferentiabilityWitnessFunctionKind witnessKind,
SILDifferentiabilityWitness *witness, Optional<SILType> functionType)
: InstructionBase(debugLoc, functionType
? *functionType
: getDifferentiabilityWitnessType(
module, witnessKind, witness)),
witnessKind(witnessKind), witness(witness),
hasExplicitFunctionType(functionType) {
assert(witness && "Differentiability witness must not be null");
#ifndef NDEBUG
if (functionType.hasValue()) {
assert(module.getStage() == SILStage::Lowered &&
"Explicit type is valid only in lowered SIL");
}
#endif
}

FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind,
SILDebugLocation DebugLoc,
SILFunction *F,
Expand Down
Loading