Skip to content

Commit 0e5c3e5

Browse files
committed
[AutoDiff] Add differentiability_witness_function instruction.
`differentiability_witness_function` looks up a differentiability witness function (JVP, VJP, or transpose) for a referenced function using SIL differentiability witnesses.
1 parent 92b0f22 commit 0e5c3e5

20 files changed

+539
-47
lines changed

docs/SIL.rst

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5696,7 +5696,7 @@ linear_function_extract
56965696
sil-value ':' sil-type
56975697

56985698
sil-linear-function-extractee ::=
5699-
'[' sil-linear-function-extractee ']'
5699+
'[' sil-linear-function-extractee-name ']'
57005700
sil-linear-function-extractee-name ::= 'original' | 'transpose'
57015701

57025702
linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
@@ -5707,6 +5707,49 @@ Extracts the original function or a transpose function from the given
57075707
``[original]`` or ``[transpose]``.
57085708

57095709

5710+
differentiability_witness_function
5711+
``````````````````````````````````
5712+
5713+
::
5714+
5715+
sil-instruction ::= 'differentiability_witness_function'
5716+
sil-differentiability-witness-function-kind
5717+
'[' 'parameters' sil-differentiability-witness-indices ']'
5718+
'[' 'results' sil-differentiability-witness-indices ']'
5719+
generic-parameter-clause?
5720+
sil-function-name ':' sil-type
5721+
5722+
sil-differentiability-witness-function-kind ::=
5723+
'[' sil-differentiability-witness-function-kind-name ']'
5724+
sil-differentiability-witness-function-kind-name ::=
5725+
'jvp' | 'vjp' | 'transpose'
5726+
sil-differentiability-witness-indices ::= [0-9]+ (' ' [0-9]+)*
5727+
generic-parameter-clause ::=
5728+
'<' generic-parameter-list generic-where-clause '>'
5729+
generic-where-clause ::=
5730+
'where' generic-requirement (',' generic-requirement)*
5731+
generic-requirement ::=
5732+
type '==' type | type ':' type | type ':' layout-constraint
5733+
5734+
differentiability_witness_function [jvp] [parameters 0] [results 0] \
5735+
<T where T: Differentiable> @foo : $(T) -> T
5736+
5737+
Looks up the differentiability witness function for the referenced function
5738+
using SIL differentiability witnesses.
5739+
5740+
The differentiability witness function kind identifies the witness function to
5741+
look up: ``[jvp]``, ``[vjp]``, or ``[transpose]``.
5742+
5743+
The remaining components identify the SIL differentiability witness:
5744+
5745+
- Original function name.
5746+
- Parameter indices.
5747+
- Result indices.
5748+
- Witness generic parameter clause (optional). When parsing SIL, the parsed
5749+
witness generic parameter clause is combined with the original function's
5750+
generic signature to form the full witness generic signature.
5751+
5752+
57105753
Assertion configuration
57115754
~~~~~~~~~~~~~~~~~~~~~~~
57125755

include/swift/AST/AutoDiff.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,27 @@ struct AutoDiffDerivativeFunctionKind {
7979
}
8080
};
8181

82+
/// The kind of a differentiability witness function.
83+
struct DifferentiabilityWitnessFunctionKind {
84+
enum innerty : uint8_t {
85+
// The Jacobian-vector products function.
86+
JVP = 0,
87+
// The vector-Jacobian products function.
88+
VJP = 1,
89+
// The transpose function.
90+
Transpose = 2
91+
} rawValue;
92+
93+
DifferentiabilityWitnessFunctionKind() = default;
94+
DifferentiabilityWitnessFunctionKind(innerty rawValue) : rawValue(rawValue) {}
95+
explicit DifferentiabilityWitnessFunctionKind(unsigned rawValue)
96+
: rawValue(static_cast<innerty>(rawValue)) {}
97+
explicit DifferentiabilityWitnessFunctionKind(StringRef name);
98+
operator innerty() const { return rawValue; }
99+
100+
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
101+
};
102+
82103
struct NormalDifferentiableFunctionTypeComponent {
83104
enum innerty : unsigned {
84105
Original = 0,

include/swift/AST/DiagnosticsParse.def

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1620,6 +1620,13 @@ ERROR(sil_inst_autodiff_expected_linear_extractee_kind,PointsToFirstBadToken,
16201620
"and '[transpose]'", ())
16211621
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
16221622
"expected an operand of a function type", ())
1623+
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,
1624+
"expected a differentiability witness kind, which can be one of '[jvp]', "
1625+
"'[vjp]', or '[transpose]'", ())
1626+
ERROR(sil_inst_autodiff_invalid_witness_generic_signature,PointsToFirstBadToken,
1627+
"expected witness_generic signature '%0' does not have same generic "
1628+
"parameters as original function generic signature '%1'",
1629+
(StringRef, StringRef))
16231630

16241631
// Quoted attribute.
16251632
ERROR(attr_quoted_enable_experimental_quasiquotes,PointsToFirstBadToken,

include/swift/SIL/SILBuilder.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,18 @@ class SILBuilder {
549549
NormalDifferentiableFunctionTypeComponent::Original, TheFunction));
550550
}
551551

552+
DifferentiabilityWitnessFunctionInst *
553+
createDifferentiabilityWitnessFunction(
554+
SILLocation Loc, SILFunction *OriginalFunction,
555+
DifferentiabilityWitnessFunctionKind WitnessKind,
556+
IndexSubset *ParameterIndices, IndexSubset *ResultIndices,
557+
GenericSignature *WitnessGenericSignature) {
558+
return insert(new (getModule()) DifferentiabilityWitnessFunctionInst(
559+
getModule(), getSILDebugLocation(Loc), OriginalFunction, WitnessKind,
560+
ParameterIndices, ResultIndices, WitnessGenericSignature));
561+
}
562+
// SWIFT_ENABLE_TENSORFLOW END
563+
552564
BuiltinInst *createBuiltin(SILLocation Loc, Identifier Name, SILType ResultTy,
553565
SubstitutionMap Subs,
554566
ArrayRef<SILValue> Args) {

include/swift/SIL/SILCloner.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,17 @@ visitLinearFunctionExtractInst(LinearFunctionExtractInst *Inst) {
10111011
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
10121012
getOpValue(Inst->getFunctionOperand())));
10131013
}
1014+
1015+
template<typename ImplClass>
1016+
void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
1017+
DifferentiabilityWitnessFunctionInst *Inst) {
1018+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
1019+
recordClonedInstruction(
1020+
Inst, getBuilder().createDifferentiabilityWitnessFunction(
1021+
getOpLocation(Inst->getLoc()), Inst->getOriginalFunction(),
1022+
Inst->getWitnessKind(), Inst->getParameterIndices(),
1023+
Inst->getResultIndices(), Inst->getWitnessGenericSignature()));
1024+
}
10141025
// SWIFT_ENABLE_TENSORFLOW END
10151026

10161027
template<typename ImplClass>

include/swift/SIL/SILInstruction.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8031,6 +8031,54 @@ class LinearFunctionExtractInst
80318031
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
80328032
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
80338033
};
8034+
8035+
class DifferentiabilityWitnessFunctionInst
8036+
: public InstructionBase<
8037+
SILInstructionKind::DifferentiabilityWitnessFunctionInst,
8038+
SingleValueInstruction> {
8039+
private:
8040+
friend SILBuilder;
8041+
/// The original function.
8042+
SILFunction *originalFunction;
8043+
/// The differentiability witness function kind.
8044+
DifferentiabilityWitnessFunctionKind witnessKind;
8045+
/// The autodiff config: parameter indices, result indices, and witness
8046+
/// derivative signature.
8047+
AutoDiffConfig config;
8048+
8049+
static SILType getDifferentiabilityWitnessType(
8050+
SILModule &module, SILFunction *originalFunction,
8051+
DifferentiabilityWitnessFunctionKind witnessKind,
8052+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8053+
GenericSignature *witnessGenericSignature);
8054+
8055+
public:
8056+
DifferentiabilityWitnessFunctionInst(
8057+
SILModule &module, SILDebugLocation loc, SILFunction *originalFunction,
8058+
DifferentiabilityWitnessFunctionKind witnessKind,
8059+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8060+
GenericSignature *witnessGenericSignature);
8061+
8062+
static DifferentiabilityWitnessFunctionInst *create(
8063+
SILModule &module, SILDebugLocation loc, SILFunction *originalFunction,
8064+
DifferentiabilityWitnessFunctionKind witnessKind,
8065+
IndexSubset *parameterIndices, IndexSubset *resultIndices,
8066+
GenericSignature *witnessGenericSignature);
8067+
8068+
DifferentiabilityWitnessFunctionKind getWitnessKind() const {
8069+
return witnessKind;
8070+
}
8071+
SILFunction *getOriginalFunction() const { return originalFunction; }
8072+
AutoDiffConfig const &getConfig() const { return config; }
8073+
IndexSubset *getParameterIndices() const { return config.parameterIndices; }
8074+
IndexSubset *getResultIndices() const { return config.resultIndices; }
8075+
GenericSignature *getWitnessGenericSignature() const {
8076+
return config.derivativeGenericSignature;
8077+
}
8078+
8079+
ArrayRef<Operand> getAllOperands() const { return {}; }
8080+
MutableArrayRef<Operand> getAllOperands() { return {}; }
8081+
};
80348082
// SWIFT_ENABLE_TENSORFLOW END
80358083

80368084
// This is defined out of line to work around the fact that this depends on

include/swift/SIL/SILNodes.def

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,10 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
700700
SINGLE_VALUE_INST(LinearFunctionExtractInst,
701701
linear_function_extract,
702702
SingleValueInstruction, None, DoesNotRelease)
703+
SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst,
704+
differentiability_witness_function,
705+
SingleValueInstruction, None, DoesNotRelease)
706+
// SWIFT_ENABLE_TENSORFLOW END
703707

704708
// Key paths
705709
// TODO: The only "side effect" is potentially retaining the returned key path

lib/AST/AutoDiff.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,25 @@ AutoDiffDerivativeFunctionKind(StringRef string) {
3232
rawValue = *result;
3333
}
3434

35+
DifferentiabilityWitnessFunctionKind::
36+
DifferentiabilityWitnessFunctionKind(StringRef string) {
37+
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
38+
.Case("jvp", JVP)
39+
.Case("vjp", VJP)
40+
.Case("transpose", Transpose);
41+
assert(result && "Invalid string");
42+
rawValue = *result;
43+
}
44+
45+
Optional<AutoDiffDerivativeFunctionKind>
46+
DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
47+
switch (rawValue) {
48+
case JVP: return {AutoDiffDerivativeFunctionKind::JVP};
49+
case VJP: return {AutoDiffDerivativeFunctionKind::VJP};
50+
case Transpose: return None;
51+
}
52+
}
53+
3554
NormalDifferentiableFunctionTypeComponent::
3655
NormalDifferentiableFunctionTypeComponent(AutoDiffDerivativeFunctionKind kind) {
3756
switch (kind) {

lib/IRGen/IRGenSIL.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,8 @@ class IRGenSILFunction :
925925
void visitDifferentiableFunctionExtractInst(
926926
DifferentiableFunctionExtractInst *i);
927927
void visitLinearFunctionExtractInst(LinearFunctionExtractInst *i);
928+
void visitDifferentiabilityWitnessFunctionInst(
929+
DifferentiabilityWitnessFunctionInst *i);
928930
// SWIFT_ENABLE_TENSORFLOW END
929931

930932
void visitFunctionRefBaseInst(FunctionRefBaseInst *i);
@@ -1927,6 +1929,13 @@ visitLinearFunctionExtractInst(LinearFunctionExtractInst *i) {
19271929
setLoweredExplosion(i, e);
19281930
}
19291931

1932+
void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst(
1933+
DifferentiabilityWitnessFunctionInst *i) {
1934+
// TODO(TF-916): Implement IRGen for `differentiability_witness_function`.
1935+
llvm_unreachable("unimplemented");
1936+
}
1937+
// SWIFT_ENABLE_TENSORFLOW END
1938+
19301939
void IRGenSILFunction::visitFunctionRefBaseInst(FunctionRefBaseInst *i) {
19311940
auto fn = i->getInitiallyReferencedFunction();
19321941

lib/ParseSIL/ParseSIL.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3091,6 +3091,110 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30913091
InstLoc, extractee, functionOperand);
30923092
break;
30933093
}
3094+
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
3095+
// e.g. differentiability_witness_function
3096+
// [jvp] [parameters 0 1] [results 0] <T where T: Differentiable>
3097+
// @foo : $(T) -> T
3098+
DifferentiabilityWitnessFunctionKind witnessKind;
3099+
StringRef witnessKindNames[3] = {"jvp", "vjp", "transpose"};
3100+
SourceLoc lastLoc;
3101+
if (P.parseToken(tok::l_square,
3102+
diag::sil_inst_autodiff_expected_differentiability_witness_kind) ||
3103+
parseSILIdentifierSwitch(witnessKind, witnessKindNames,
3104+
diag::sil_inst_autodiff_expected_differentiability_witness_kind) ||
3105+
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
3106+
"differentiability witness function kind"))
3107+
return true;
3108+
// Parse an index set, prefaced with the given label.
3109+
auto parseIndexSet = [&](StringRef label, SmallVectorImpl<unsigned> &indices,
3110+
const Diagnostic &parseIndexDiag) -> bool {
3111+
// Parse `[<label> <integer_literal>...]`.
3112+
if (P.parseToken(tok::l_square, diag::sil_autodiff_expected_lsquare,
3113+
"index list") ||
3114+
P.parseSpecificIdentifier(
3115+
label, diag::sil_autodiff_expected_index_list_label, label))
3116+
return true;
3117+
while (P.Tok.is(tok::integer_literal)) {
3118+
unsigned index;
3119+
if (P.parseUnsignedInteger(index, lastLoc, parseIndexDiag))
3120+
return true;
3121+
indices.push_back(index);
3122+
}
3123+
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
3124+
"index list"))
3125+
return true;
3126+
return false;
3127+
};
3128+
// Parse parameter and result indices.
3129+
SmallVector<unsigned, 8> parameterIndices;
3130+
SmallVector<unsigned, 8> resultIndices;
3131+
if (parseIndexSet("parameters", parameterIndices,
3132+
diag::sil_autodiff_expected_parameter_index))
3133+
return true;
3134+
if (parseIndexSet("results", resultIndices,
3135+
diag::sil_autodiff_expected_result_index))
3136+
return true;
3137+
// Parse witness generic parameter clause.
3138+
GenericSignature *witnessGenSig = nullptr;
3139+
SourceLoc witnessGenSigStartLoc = P.getEndOfPreviousLoc();
3140+
{
3141+
// Create a new scope to avoid type redefinition errors.
3142+
Scope genericsScope(&P, ScopeKind::Generics);
3143+
auto *genericParams = P.parseSILGenericParams().getPtrOrNull();
3144+
if (genericParams) {
3145+
auto *witnessGenEnv =
3146+
handleSILGenericParams(P.Context, genericParams, &P.SF);
3147+
witnessGenSig = witnessGenEnv->getGenericSignature();
3148+
}
3149+
}
3150+
// Parse original function name and type.
3151+
SILFunction *originalFunction;
3152+
if (parseSILFunctionRef(InstLoc, originalFunction))
3153+
return true;
3154+
// Resolve parsed witness generic signature.
3155+
if (witnessGenSig) {
3156+
auto origGenSig = originalFunction
3157+
->getLoweredFunctionType()->getGenericSignature();
3158+
// Check whether original function generic signature and parsed witness
3159+
// generic have the same generic parameters.
3160+
auto areGenericParametersConsistent = [&]() {
3161+
llvm::DenseSet<GenericParamKey> genericParamKeys;
3162+
for (auto *origGP : origGenSig->getGenericParams())
3163+
genericParamKeys.insert(GenericParamKey(origGP));
3164+
for (auto *witnessGP : witnessGenSig->getGenericParams())
3165+
if (!genericParamKeys.erase(GenericParamKey(witnessGP)))
3166+
return false;
3167+
return genericParamKeys.empty();
3168+
};
3169+
if (!areGenericParametersConsistent()) {
3170+
P.diagnose(witnessGenSigStartLoc,
3171+
diag::sil_inst_autodiff_invalid_witness_generic_signature,
3172+
witnessGenSig->getAsString(), origGenSig->getAsString());
3173+
return true;
3174+
}
3175+
// Combine parsed witness requirements with original function generic
3176+
// signature requirements to form full witness generic signature.
3177+
SmallVector<Requirement, 4> witnessRequirements(
3178+
witnessGenSig->getRequirements().begin(),
3179+
witnessGenSig->getRequirements().end());
3180+
witnessGenSig = evaluateOrDefault(
3181+
P.Context.evaluator,
3182+
AbstractGenericSignatureRequest{
3183+
origGenSig,
3184+
/*addedGenericParams=*/{},
3185+
std::move(witnessRequirements)},
3186+
nullptr);
3187+
}
3188+
auto origFnType = originalFunction->getLoweredFunctionType();
3189+
auto *parameterIndexSet = IndexSubset::get(
3190+
P.Context, origFnType->getNumParameters(), parameterIndices);
3191+
auto *resultIndexSet = IndexSubset::get(
3192+
P.Context, origFnType->getNumResults(), resultIndices);
3193+
ResultVal = B.createDifferentiabilityWitnessFunction(
3194+
InstLoc, originalFunction, witnessKind, parameterIndexSet,
3195+
resultIndexSet, witnessGenSig);
3196+
break;
3197+
}
30943198
// SWIFT_ENABLE_TENSORFLOW END
30953199

30963200
case SILInstructionKind::DynamicFunctionRefInst: {

lib/SIL/OperandOwnership.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ SHOULD_NEVER_VISIT_INST(Unwind)
136136
SHOULD_NEVER_VISIT_INST(ReleaseValue)
137137
SHOULD_NEVER_VISIT_INST(ReleaseValueAddr)
138138
SHOULD_NEVER_VISIT_INST(StrongRelease)
139+
// SWIFT_ENABLE_TENSORFLOW
140+
SHOULD_NEVER_VISIT_INST(DifferentiabilityWitnessFunction)
141+
// SWIFT_ENABLE_TENSORFLOW END
139142
#define ALWAYS_OR_SOMETIMES_LOADABLE_CHECKED_REF_STORAGE(Name, ...) \
140143
SHOULD_NEVER_VISIT_INST(StrongRetain##Name) \
141144
SHOULD_NEVER_VISIT_INST(Name##Retain)

0 commit comments

Comments
 (0)