Skip to content

[AutoDiff] Add differentiability_witness_function instruction. #27719

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 38 additions & 8 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5671,12 +5671,10 @@ differentiable_function_extract
::

sil-instruction ::= 'differentiable_function_extract'
sil-differentiable-function-extractee
'[' sil-differentiable-function-extractee ']'
sil-value ':' sil-type

sil-differentiable-function-extractee ::=
'[' sil-differentiable-function-extractee ']'
sil-differentiable-function-extractee-name ::= 'original' | 'jvp' | 'vjp'
sil-differentiable-function-extractee ::= 'original' | 'jvp' | 'vjp'

differentiable_function_extract [original] %0 : $@differentiable (T) -> T
differentiable_function_extract [jvp] %0 : $@differentiable (T) -> T
Expand All @@ -5692,12 +5690,10 @@ linear_function_extract
::

sil-instruction ::= 'linear_function_extract'
sil-linear-function-extractee
'[' sil-linear-function-extractee ']'
sil-value ':' sil-type

sil-linear-function-extractee ::=
'[' sil-linear-function-extractee ']'
sil-linear-function-extractee-name ::= 'original' | 'transpose'
sil-linear-function-extractee ::= 'original' | 'transpose'

linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T
Expand All @@ -5707,6 +5703,40 @@ Extracts the original function or a transpose function from the given
``[original]`` or ``[transpose]``.


differentiability_witness_function
``````````````````````````````````
::

sil-instruction ::=
'differentiability_witness_function'
'[' sil-differentiability-witness-function-kind ']'
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
'[' 'results' sil-differentiability-witness-function-index-list ']'
generic-parameter-clause?
sil-function-name ':' sil-type

sil-differentiability-witness-function-kind ::= 'jvp' | 'vjp' | 'transpose'
sil-differentiability-witness-function-index-list ::= [0-9]+ (' ' [0-9]+)*

differentiability_witness_function [jvp] [parameters 0] [results 0] \
<T where T: Differentiable> @foo : $(T) -> T

Looks up the differentiability witness function for the referenced function
using SIL differentiability witnesses.

The differentiability witness function kind identifies the witness function to
look up: ``[jvp]``, ``[vjp]``, or ``[transpose]``.

The remaining components identify the SIL differentiability witness:

- Original function name.
- Parameter indices.
- Result indices.
- Witness generic parameter clause (optional). When parsing SIL, the parsed
witness generic parameter clause is combined with the original function's
generic signature to form the full witness generic signature.


Assertion configuration
~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
21 changes: 21 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ struct AutoDiffDerivativeFunctionKind {
}
};

/// The kind of a differentiability witness function.
struct DifferentiabilityWitnessFunctionKind {
enum innerty : uint8_t {
// The Jacobian-vector products function.
JVP = 0,
// The vector-Jacobian products function.
VJP = 1,
// The transpose function.
Transpose = 2
} rawValue;

DifferentiabilityWitnessFunctionKind() = default;
DifferentiabilityWitnessFunctionKind(innerty rawValue) : rawValue(rawValue) {}
explicit DifferentiabilityWitnessFunctionKind(unsigned rawValue)
: rawValue(static_cast<innerty>(rawValue)) {}
explicit DifferentiabilityWitnessFunctionKind(StringRef name);
operator innerty() const { return rawValue; }

Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
};

struct NormalDifferentiableFunctionTypeComponent {
enum innerty : unsigned {
Original = 0,
Expand Down
7 changes: 7 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1620,6 +1620,13 @@ ERROR(sil_inst_autodiff_expected_linear_extractee_kind,PointsToFirstBadToken,
"and '[transpose]'", ())
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
"expected an operand of a function type", ())
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,
"expected a differentiability witness kind, which can be one of '[jvp]', "
"'[vjp]', or '[transpose]'", ())
ERROR(sil_inst_autodiff_invalid_witness_generic_signature,PointsToFirstBadToken,
"expected witness_generic signature '%0' does not have same generic "
"parameters as original function generic signature '%1'",
(StringRef, StringRef))

// Quoted attribute.
ERROR(attr_quoted_enable_experimental_quasiquotes,PointsToFirstBadToken,
Expand Down
12 changes: 12 additions & 0 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,18 @@ class SILBuilder {
NormalDifferentiableFunctionTypeComponent::Original, TheFunction));
}

DifferentiabilityWitnessFunctionInst *
createDifferentiabilityWitnessFunction(
SILLocation Loc, SILFunction *OriginalFunction,
DifferentiabilityWitnessFunctionKind WitnessKind,
IndexSubset *ParameterIndices, IndexSubset *ResultIndices,
GenericSignature *WitnessGenericSignature) {
return insert(new (getModule()) DifferentiabilityWitnessFunctionInst(
getModule(), getSILDebugLocation(Loc), OriginalFunction, WitnessKind,
ParameterIndices, ResultIndices, WitnessGenericSignature));
}
// SWIFT_ENABLE_TENSORFLOW END

BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,
SubstitutionMap Subs,
ArrayRef<SILValue> Args) {
Expand Down
11 changes: 11 additions & 0 deletions include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,17 @@ visitLinearFunctionExtractInst(LinearFunctionExtractInst *Inst) {
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
getOpValue(Inst->getFunctionOperand())));
}

template<typename ImplClass>
void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
recordClonedInstruction(
Inst, getBuilder().createDifferentiabilityWitnessFunction(
getOpLocation(Inst->getLoc()), Inst->getOriginalFunction(),
Inst->getWitnessKind(), Inst->getParameterIndices(),
Inst->getResultIndices(), Inst->getWitnessGenericSignature()));
}
// SWIFT_ENABLE_TENSORFLOW END

template<typename ImplClass>
Expand Down
48 changes: 48 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -8031,6 +8031,54 @@ class LinearFunctionExtractInst
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
};

class DifferentiabilityWitnessFunctionInst
: public InstructionBase<
SILInstructionKind::DifferentiabilityWitnessFunctionInst,
SingleValueInstruction> {
private:
friend SILBuilder;
/// The original function.
SILFunction *originalFunction;
/// The differentiability witness function kind.
DifferentiabilityWitnessFunctionKind witnessKind;
/// The autodiff config: parameter indices, result indices, and witness
/// derivative signature.
AutoDiffConfig config;

static SILType getDifferentiabilityWitnessType(
SILModule &module, SILFunction *originalFunction,
DifferentiabilityWitnessFunctionKind witnessKind,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature *witnessGenericSignature);

public:
DifferentiabilityWitnessFunctionInst(
SILModule &module, SILDebugLocation loc, SILFunction *originalFunction,
DifferentiabilityWitnessFunctionKind witnessKind,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature *witnessGenericSignature);

static DifferentiabilityWitnessFunctionInst *create(
SILModule &module, SILDebugLocation loc, SILFunction *originalFunction,
DifferentiabilityWitnessFunctionKind witnessKind,
IndexSubset *parameterIndices, IndexSubset *resultIndices,
GenericSignature *witnessGenericSignature);

DifferentiabilityWitnessFunctionKind getWitnessKind() const {
return witnessKind;
}
SILFunction *getOriginalFunction() const { return originalFunction; }
AutoDiffConfig const &getConfig() const { return config; }
IndexSubset *getParameterIndices() const { return config.parameterIndices; }
IndexSubset *getResultIndices() const { return config.resultIndices; }
GenericSignature *getWitnessGenericSignature() const {
return config.derivativeGenericSignature;
}

ArrayRef<Operand> getAllOperands() const { return {}; }
MutableArrayRef<Operand> getAllOperands() { return {}; }
};
// SWIFT_ENABLE_TENSORFLOW END

// This is defined out of line to work around the fact that this depends on
Expand Down
4 changes: 4 additions & 0 deletions include/swift/SIL/SILNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,10 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
SINGLE_VALUE_INST(LinearFunctionExtractInst,
linear_function_extract,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst,
differentiability_witness_function,
SingleValueInstruction, None, DoesNotRelease)
// SWIFT_ENABLE_TENSORFLOW END

// Key paths
// TODO: The only "side effect" is potentially retaining the returned key path
Expand Down
19 changes: 19 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,25 @@ AutoDiffDerivativeFunctionKind(StringRef string) {
rawValue = *result;
}

DifferentiabilityWitnessFunctionKind::
DifferentiabilityWitnessFunctionKind(StringRef string) {
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
.Case("jvp", JVP)
.Case("vjp", VJP)
.Case("transpose", Transpose);
assert(result && "Invalid string");
rawValue = *result;
}

Optional<AutoDiffDerivativeFunctionKind>
DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
switch (rawValue) {
case JVP: return {AutoDiffDerivativeFunctionKind::JVP};
case VJP: return {AutoDiffDerivativeFunctionKind::VJP};
case Transpose: return None;
}
}

NormalDifferentiableFunctionTypeComponent::
NormalDifferentiableFunctionTypeComponent(AutoDiffDerivativeFunctionKind kind) {
switch (kind) {
Expand Down
9 changes: 9 additions & 0 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,8 @@ class IRGenSILFunction :
void visitDifferentiableFunctionExtractInst(
DifferentiableFunctionExtractInst *i);
void visitLinearFunctionExtractInst(LinearFunctionExtractInst *i);
void visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *i);
// SWIFT_ENABLE_TENSORFLOW END

void visitFunctionRefBaseInst(FunctionRefBaseInst *i);
Expand Down Expand Up @@ -1927,6 +1929,13 @@ visitLinearFunctionExtractInst(LinearFunctionExtractInst *i) {
setLoweredExplosion(i, e);
}

void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *i) {
// TODO(TF-916): Implement IRGen for `differentiability_witness_function`.
llvm_unreachable("unimplemented");
}
// SWIFT_ENABLE_TENSORFLOW END

void IRGenSILFunction::visitFunctionRefBaseInst(FunctionRefBaseInst *i) {
auto fn = i->getInitiallyReferencedFunction();

Expand Down
104 changes: 104 additions & 0 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3091,6 +3091,110 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
InstLoc, extractee, functionOperand);
break;
}
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
// e.g. differentiability_witness_function
// [jvp] [parameters 0 1] [results 0] <T where T: Differentiable>
// @foo : $(T) -> T
DifferentiabilityWitnessFunctionKind witnessKind;
StringRef witnessKindNames[3] = {"jvp", "vjp", "transpose"};
SourceLoc lastLoc;
if (P.parseToken(tok::l_square,
diag::sil_inst_autodiff_expected_differentiability_witness_kind) ||
parseSILIdentifierSwitch(witnessKind, witnessKindNames,
diag::sil_inst_autodiff_expected_differentiability_witness_kind) ||
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
"differentiability witness function kind"))
return true;
// Parse an index set, prefaced with the given label.
auto parseIndexSet = [&](StringRef label, SmallVectorImpl<unsigned> &indices,
const Diagnostic &parseIndexDiag) -> bool {
// Parse `[<label> <integer_literal>...]`.
if (P.parseToken(tok::l_square, diag::sil_autodiff_expected_lsquare,
"index list") ||
P.parseSpecificIdentifier(
label, diag::sil_autodiff_expected_index_list_label, label))
return true;
while (P.Tok.is(tok::integer_literal)) {
unsigned index;
if (P.parseUnsignedInteger(index, lastLoc, parseIndexDiag))
return true;
indices.push_back(index);
}
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
"index list"))
return true;
return false;
};
// Parse parameter and result indices.
SmallVector<unsigned, 8> parameterIndices;
SmallVector<unsigned, 8> resultIndices;
if (parseIndexSet("parameters", parameterIndices,
diag::sil_autodiff_expected_parameter_index))
return true;
if (parseIndexSet("results", resultIndices,
diag::sil_autodiff_expected_result_index))
return true;
// Parse witness generic parameter clause.
GenericSignature *witnessGenSig = nullptr;
SourceLoc witnessGenSigStartLoc = P.getEndOfPreviousLoc();
{
// Create a new scope to avoid type redefinition errors.
Scope genericsScope(&P, ScopeKind::Generics);
auto *genericParams = P.parseSILGenericParams().getPtrOrNull();
if (genericParams) {
auto *witnessGenEnv =
handleSILGenericParams(P.Context, genericParams, &P.SF);
witnessGenSig = witnessGenEnv->getGenericSignature();
}
}
// Parse original function name and type.
SILFunction *originalFunction;
if (parseSILFunctionRef(InstLoc, originalFunction))
return true;
// Resolve parsed witness generic signature.
if (witnessGenSig) {
auto origGenSig = originalFunction
->getLoweredFunctionType()->getGenericSignature();
// Check whether original function generic signature and parsed witness
// generic have the same generic parameters.
auto areGenericParametersConsistent = [&]() {
llvm::DenseSet<GenericParamKey> genericParamKeys;
for (auto *origGP : origGenSig->getGenericParams())
genericParamKeys.insert(GenericParamKey(origGP));
for (auto *witnessGP : witnessGenSig->getGenericParams())
if (!genericParamKeys.erase(GenericParamKey(witnessGP)))
return false;
return genericParamKeys.empty();
};
if (!areGenericParametersConsistent()) {
P.diagnose(witnessGenSigStartLoc,
diag::sil_inst_autodiff_invalid_witness_generic_signature,
witnessGenSig->getAsString(), origGenSig->getAsString());
return true;
}
// Combine parsed witness requirements with original function generic
// signature requirements to form full witness generic signature.
SmallVector<Requirement, 4> witnessRequirements(
witnessGenSig->getRequirements().begin(),
witnessGenSig->getRequirements().end());
witnessGenSig = evaluateOrDefault(
P.Context.evaluator,
AbstractGenericSignatureRequest{
origGenSig,
/*addedGenericParams=*/{},
std::move(witnessRequirements)},
nullptr);
}
auto origFnType = originalFunction->getLoweredFunctionType();
auto *parameterIndexSet = IndexSubset::get(
P.Context, origFnType->getNumParameters(), parameterIndices);
auto *resultIndexSet = IndexSubset::get(
P.Context, origFnType->getNumResults(), resultIndices);
ResultVal = B.createDifferentiabilityWitnessFunction(
InstLoc, originalFunction, witnessKind, parameterIndexSet,
resultIndexSet, witnessGenSig);
break;
}
// SWIFT_ENABLE_TENSORFLOW END

case SILInstructionKind::DynamicFunctionRefInst: {
Expand Down
3 changes: 3 additions & 0 deletions lib/SIL/OperandOwnership.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ SHOULD_NEVER_VISIT_INST(Unwind)
SHOULD_NEVER_VISIT_INST(ReleaseValue)
SHOULD_NEVER_VISIT_INST(ReleaseValueAddr)
SHOULD_NEVER_VISIT_INST(StrongRelease)
// SWIFT_ENABLE_TENSORFLOW
SHOULD_NEVER_VISIT_INST(DifferentiabilityWitnessFunction)
// SWIFT_ENABLE_TENSORFLOW END
#define ALWAYS_OR_SOMETIMES_LOADABLE_CHECKED_REF_STORAGE(Name, ...) \
SHOULD_NEVER_VISIT_INST(StrongRetain##Name) \
SHOULD_NEVER_VISIT_INST(Name##Retain)
Expand Down
Loading