Skip to content

Commit eaba367

Browse files
committed
[AutoDiff upstream] Add differentiability_witness_function instruction.
The `differentiability_witness_function` instruction looks up a differentiability witness function (JVP, VJP, or transpose) for a referenced function via SIL differentiability witnesses. Add round-trip parsing/serialization and IRGen tests. Notes: - Differentiability witnesses for linear functions require more support. `differentiability_witness_function [transpose]` instructions do not yet have IRGen. - Nothing currently generates `differentiability_witness_function` instructions. The differentiation transform does, but it hasn't been upstreamed yet. Resolves TF-1141.
1 parent 518509d commit eaba367

22 files changed

+553
-2
lines changed

docs/SIL.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5666,6 +5666,42 @@ destination (if it returns with ``throw``).
56665666

56675667
The rules on generic substitutions are identical to those of ``apply``.
56685668

5669+
Differentiable Programming
5670+
~~~~~~~~~~~~~~~~~~~~~~~~~~
5671+
5672+
differentiability_witness_function
5673+
``````````````````````````````````
5674+
::
5675+
5676+
sil-instruction ::=
5677+
'differentiability_witness_function'
5678+
'[' sil-differentiability-witness-function-kind ']'
5679+
'[' 'parameters' sil-differentiability-witness-function-index-list ']'
5680+
'[' 'results' sil-differentiability-witness-function-index-list ']'
5681+
generic-parameter-clause?
5682+
sil-function-name ':' sil-type
5683+
5684+
sil-differentiability-witness-function-kind ::= 'jvp' | 'vjp' | 'transpose'
5685+
sil-differentiability-witness-function-index-list ::= [0-9]+ (' ' [0-9]+)*
5686+
5687+
differentiability_witness_function [jvp] [parameters 0] [results 0] \
5688+
<T where T: Differentiable> @foo : $(T) -> T
5689+
5690+
Looks up a differentiability witness function (JVP, VJP, or transpose) for
5691+
a referenced function via SIL differentiability witnesses.
5692+
5693+
The differentiability witness function kind identifies the witness function to
5694+
look up: ``[jvp]``, ``[vjp]``, or ``[transpose]``.
5695+
5696+
The remaining components identify the SIL differentiability witness:
5697+
5698+
- Original function name.
5699+
- Parameter indices.
5700+
- Result indices.
5701+
- Witness generic parameter clause (optional). When parsing SIL, the parsed
5702+
witness generic parameter clause is combined with the original function's
5703+
generic signature to form the full witness generic signature.
5704+
56695705
Assertion configuration
56705706
~~~~~~~~~~~~~~~~~~~~~~~
56715707

include/swift/AST/AutoDiff.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,27 @@ struct AutoDiffDerivativeFunctionKind {
7373
}
7474
};
7575

76+
/// The kind of a differentiability witness function.
77+
struct DifferentiabilityWitnessFunctionKind {
78+
enum innerty : uint8_t {
79+
// The Jacobian-vector products function.
80+
JVP = 0,
81+
// The vector-Jacobian products function.
82+
VJP = 1,
83+
// The transpose function.
84+
Transpose = 2
85+
} rawValue;
86+
87+
DifferentiabilityWitnessFunctionKind() = default;
88+
DifferentiabilityWitnessFunctionKind(innerty rawValue) : rawValue(rawValue) {}
89+
explicit DifferentiabilityWitnessFunctionKind(unsigned rawValue)
90+
: rawValue(static_cast<innerty>(rawValue)) {}
91+
explicit DifferentiabilityWitnessFunctionKind(StringRef name);
92+
operator innerty() const { return rawValue; }
93+
94+
Optional<AutoDiffDerivativeFunctionKind> getAsDerivativeFunctionKind() const;
95+
};
96+
7697
/// Identifies an autodiff derivative function configuration:
7798
/// - Parameter indices.
7899
/// - Result indices.

include/swift/AST/DiagnosticsParse.def

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,13 @@ ERROR(sil_autodiff_expected_parameter_index,PointsToFirstBadToken,
16041604
"expected the index of a parameter to differentiate with respect to", ())
16051605
ERROR(sil_autodiff_expected_result_index,PointsToFirstBadToken,
16061606
"expected the index of a result to differentiate from", ())
1607+
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,
1608+
"expected a differentiability witness kind, which can be one of '[jvp]', "
1609+
"'[vjp]', or '[transpose]'", ())
1610+
ERROR(sil_inst_autodiff_invalid_witness_generic_signature,PointsToFirstBadToken,
1611+
"expected witness_generic signature '%0' does not have same generic "
1612+
"parameters as original function generic signature '%1'",
1613+
(StringRef, StringRef))
16071614

16081615
//------------------------------------------------------------------------------
16091616
// MARK: Generics parsing diagnostics

include/swift/SIL/SILBuilder.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2157,6 +2157,20 @@ class SILBuilder {
21572157
SILValue emitThickToObjCMetatype(SILLocation Loc, SILValue Op, SILType Ty);
21582158
SILValue emitObjCToThickMetatype(SILLocation Loc, SILValue Op, SILType Ty);
21592159

2160+
//===--------------------------------------------------------------------===//
2161+
// Differentiable programming instructions
2162+
//===--------------------------------------------------------------------===//
2163+
2164+
/// Note: explicit function type may be specified only in lowered SIL.
2165+
DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction(
2166+
SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind,
2167+
SILDifferentiabilityWitness *Witness,
2168+
Optional<SILType> FunctionType = None) {
2169+
return insert(new (getModule()) DifferentiabilityWitnessFunctionInst(
2170+
getModule(), getSILDebugLocation(Loc), WitnessKind, Witness,
2171+
FunctionType));
2172+
}
2173+
21602174
//===--------------------------------------------------------------------===//
21612175
// Private Helper Methods
21622176
//===--------------------------------------------------------------------===//

include/swift/SIL/SILCloner.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,6 +2825,16 @@ void SILCloner<ImplClass>::visitKeyPathInst(KeyPathInst *Inst) {
28252825
opValues, getOpType(Inst->getType())));
28262826
}
28272827

2828+
template <typename ImplClass>
2829+
void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
2830+
DifferentiabilityWitnessFunctionInst *Inst) {
2831+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
2832+
recordClonedInstruction(Inst,
2833+
getBuilder().createDifferentiabilityWitnessFunction(
2834+
getOpLocation(Inst->getLoc()),
2835+
Inst->getWitnessKind(), Inst->getWitness()));
2836+
}
2837+
28282838
} // end namespace swift
28292839

28302840
#endif

include/swift/SIL/SILInstruction.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#ifndef SWIFT_SIL_INSTRUCTION_H
1818
#define SWIFT_SIL_INSTRUCTION_H
1919

20+
#include "swift/AST/AutoDiff.h"
2021
#include "swift/AST/Builtins.h"
2122
#include "swift/AST/Decl.h"
2223
#include "swift/AST/GenericSignature.h"
@@ -61,6 +62,7 @@ class SILBasicBlock;
6162
class SILBuilder;
6263
class SILDebugLocation;
6364
class SILDebugScope;
65+
class SILDifferentiabilityWitness;
6466
class SILFunction;
6567
class SILGlobalVariable;
6668
class SILInstructionResultArray;
@@ -7931,6 +7933,40 @@ class TryApplyInst final
79317933
const GenericSpecializationInformation *SpecializationInfo);
79327934
};
79337935

7936+
class DifferentiabilityWitnessFunctionInst
7937+
: public InstructionBase<
7938+
SILInstructionKind::DifferentiabilityWitnessFunctionInst,
7939+
SingleValueInstruction> {
7940+
private:
7941+
friend SILBuilder;
7942+
/// The differentiability witness function kind.
7943+
DifferentiabilityWitnessFunctionKind witnessKind;
7944+
/// The referenced SIL differentiability witness.
7945+
SILDifferentiabilityWitness *witness;
7946+
/// Whether the instruction has an explicit function type.
7947+
bool hasExplicitFunctionType;
7948+
7949+
static SILType getDifferentiabilityWitnessType(
7950+
SILModule &module, DifferentiabilityWitnessFunctionKind witnessKind,
7951+
SILDifferentiabilityWitness *witness);
7952+
7953+
public:
7954+
/// Note: explicit function type may be specified only in lowered SIL.
7955+
DifferentiabilityWitnessFunctionInst(
7956+
SILModule &module, SILDebugLocation loc,
7957+
DifferentiabilityWitnessFunctionKind witnessKind,
7958+
SILDifferentiabilityWitness *witness, Optional<SILType> FunctionType);
7959+
7960+
DifferentiabilityWitnessFunctionKind getWitnessKind() const {
7961+
return witnessKind;
7962+
}
7963+
SILDifferentiabilityWitness *getWitness() const { return witness; }
7964+
bool getHasExplicitFunctionType() const { return hasExplicitFunctionType; }
7965+
7966+
ArrayRef<Operand> getAllOperands() const { return {}; }
7967+
MutableArrayRef<Operand> getAllOperands() { return {}; }
7968+
};
7969+
79347970
// This is defined out of line to work around the fact that this depends on
79357971
// PartialApplyInst being defined, but PartialApplyInst is a subclass of
79367972
// ApplyInstBase, so we can not place ApplyInstBase after it.

include/swift/SIL/SILNodes.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,11 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
688688
SINGLE_VALUE_INST(InitBlockStorageHeaderInst, init_block_storage_header,
689689
SingleValueInstruction, None, DoesNotRelease)
690690

691+
// Differentiable programming
692+
SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst,
693+
differentiability_witness_function,
694+
SingleValueInstruction, None, DoesNotRelease)
695+
691696
// Key paths
692697
// TODO: The only "side effect" is potentially retaining the returned key path
693698
// object; is there a more specific effect?

lib/AST/AutoDiff.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2019 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -12,11 +12,34 @@
1212

1313
#include "swift/AST/AutoDiff.h"
1414
#include "swift/AST/ASTContext.h"
15+
#include "swift/AST/Module.h"
1516
#include "swift/AST/TypeCheckRequests.h"
1617
#include "swift/AST/Types.h"
1718

1819
using namespace swift;
1920

21+
DifferentiabilityWitnessFunctionKind::DifferentiabilityWitnessFunctionKind(
22+
StringRef string) {
23+
Optional<innerty> result = llvm::StringSwitch<Optional<innerty>>(string)
24+
.Case("jvp", JVP)
25+
.Case("vjp", VJP)
26+
.Case("transpose", Transpose);
27+
assert(result && "Invalid string");
28+
rawValue = *result;
29+
}
30+
31+
Optional<AutoDiffDerivativeFunctionKind>
32+
DifferentiabilityWitnessFunctionKind::getAsDerivativeFunctionKind() const {
33+
switch (rawValue) {
34+
case JVP:
35+
return {AutoDiffDerivativeFunctionKind::JVP};
36+
case VJP:
37+
return {AutoDiffDerivativeFunctionKind::VJP};
38+
case Transpose:
39+
return None;
40+
}
41+
}
42+
2043
void AutoDiffConfig::print(llvm::raw_ostream &s) const {
2144
s << "(parameters=";
2245
parameterIndices->print(s);

lib/IRGen/IRGenSIL.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,9 @@ class IRGenSILFunction :
10421042

10431043
void visitKeyPathInst(KeyPathInst *I);
10441044

1045+
void visitDifferentiabilityWitnessFunctionInst(
1046+
DifferentiabilityWitnessFunctionInst *i);
1047+
10451048
#define LOADABLE_REF_STORAGE_HELPER(Name) \
10461049
void visitRefTo##Name##Inst(RefTo##Name##Inst *i); \
10471050
void visit##Name##ToRefInst(Name##ToRefInst *i); \
@@ -1815,6 +1818,33 @@ void IRGenSILFunction::visitSILBasicBlock(SILBasicBlock *BB) {
18151818
assert(Builder.hasPostTerminatorIP() && "SIL bb did not terminate block?!");
18161819
}
18171820

1821+
void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst(
1822+
DifferentiabilityWitnessFunctionInst *i) {
1823+
llvm::Value *diffWitness =
1824+
IGM.getAddrOfDifferentiabilityWitness(i->getWitness());
1825+
unsigned offset = 0;
1826+
switch (i->getWitnessKind()) {
1827+
case DifferentiabilityWitnessFunctionKind::JVP:
1828+
offset = 0;
1829+
break;
1830+
case DifferentiabilityWitnessFunctionKind::VJP:
1831+
offset = 1;
1832+
break;
1833+
case DifferentiabilityWitnessFunctionKind::Transpose:
1834+
llvm_unreachable("Not yet implemented");
1835+
}
1836+
1837+
diffWitness = Builder.CreateStructGEP(diffWitness, offset);
1838+
diffWitness = Builder.CreateLoad(diffWitness, IGM.getPointerAlignment());
1839+
1840+
auto fnType = cast<SILFunctionType>(i->getType().getASTType());
1841+
Signature signature = IGM.getSignature(fnType);
1842+
diffWitness =
1843+
Builder.CreateBitCast(diffWitness, signature.getType()->getPointerTo());
1844+
1845+
setLoweredFunctionPointer(i, FunctionPointer(diffWitness, signature));
1846+
}
1847+
18181848
void IRGenSILFunction::visitFunctionRefBaseInst(FunctionRefBaseInst *i) {
18191849
auto fn = i->getInitiallyReferencedFunction();
18201850

lib/ParseSIL/ParseSIL.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5023,6 +5023,47 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
50235023
blockType, subMap);
50245024
break;
50255025
}
5026+
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
5027+
// e.g. differentiability_witness_function
5028+
// [jvp] [parameters 0 1] [results 0] <T where T: Differentiable>
5029+
// @foo : <T> $(T) -> T
5030+
DifferentiabilityWitnessFunctionKind witnessKind;
5031+
StringRef witnessKindNames[3] = {"jvp", "vjp", "transpose"};
5032+
if (P.parseToken(
5033+
tok::l_square,
5034+
diag::
5035+
sil_inst_autodiff_expected_differentiability_witness_kind) ||
5036+
parseSILIdentifierSwitch(
5037+
witnessKind, witnessKindNames,
5038+
diag::
5039+
sil_inst_autodiff_expected_differentiability_witness_kind) ||
5040+
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
5041+
"differentiability witness function kind"))
5042+
return true;
5043+
SourceLoc keyStartLoc = P.Tok.getLoc();
5044+
auto configAndFn =
5045+
parseSILDifferentiabilityWitnessConfigAndFunction(P, *this, InstLoc);
5046+
if (!configAndFn)
5047+
return true;
5048+
auto config = configAndFn->first;
5049+
auto originalFn = configAndFn->second;
5050+
auto *witness = SILMod.lookUpDifferentiabilityWitness(
5051+
{originalFn->getName(), config});
5052+
if (!witness) {
5053+
P.diagnose(keyStartLoc, diag::sil_diff_witness_undefined);
5054+
return true;
5055+
}
5056+
// Parse an optional explicit function type.
5057+
Optional<SILType> functionType = None;
5058+
if (P.consumeIf(tok::kw_as)) {
5059+
functionType = SILType();
5060+
if (parseSILType(*functionType))
5061+
return true;
5062+
}
5063+
ResultVal = B.createDifferentiabilityWitnessFunction(
5064+
InstLoc, witnessKind, witness, functionType);
5065+
break;
5066+
}
50265067
}
50275068

50285069
return false;

lib/SIL/OperandOwnership.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ SHOULD_NEVER_VISIT_INST(AllocBox)
116116
SHOULD_NEVER_VISIT_INST(AllocExistentialBox)
117117
SHOULD_NEVER_VISIT_INST(AllocGlobal)
118118
SHOULD_NEVER_VISIT_INST(AllocStack)
119+
SHOULD_NEVER_VISIT_INST(DifferentiabilityWitnessFunction)
119120
SHOULD_NEVER_VISIT_INST(FloatLiteral)
120121
SHOULD_NEVER_VISIT_INST(FunctionRef)
121122
SHOULD_NEVER_VISIT_INST(DynamicFunctionRef)

lib/SIL/SILInstructions.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,50 @@ TryApplyInst *TryApplyInst::create(
605605
normalBB, errorBB, specializationInfo);
606606
}
607607

608+
SILType DifferentiabilityWitnessFunctionInst::getDifferentiabilityWitnessType(
609+
SILModule &module, DifferentiabilityWitnessFunctionKind witnessKind,
610+
SILDifferentiabilityWitness *witness) {
611+
auto fnTy = witness->getOriginalFunction()->getLoweredFunctionType();
612+
CanGenericSignature witnessCanGenSig;
613+
if (auto witnessGenSig = witness->getDerivativeGenericSignature())
614+
witnessCanGenSig = witnessGenSig->getCanonicalSignature();
615+
auto *parameterIndices = witness->getParameterIndices();
616+
auto *resultIndices = witness->getResultIndices();
617+
if (auto derivativeKind = witnessKind.getAsDerivativeFunctionKind()) {
618+
bool isReabstractionThunk =
619+
witness->getOriginalFunction()->isThunk() == IsReabstractionThunk;
620+
auto diffFnTy = fnTy->getAutoDiffDerivativeFunctionType(
621+
parameterIndices, *resultIndices->begin(), *derivativeKind,
622+
module.Types, LookUpConformanceInModule(module.getSwiftModule()),
623+
witnessCanGenSig, isReabstractionThunk);
624+
return SILType::getPrimitiveObjectType(diffFnTy);
625+
}
626+
assert(witnessKind == DifferentiabilityWitnessFunctionKind::Transpose);
627+
auto transposeFnTy = fnTy->getAutoDiffTransposeFunctionType(
628+
parameterIndices, module.Types,
629+
LookUpConformanceInModule(module.getSwiftModule()), witnessCanGenSig);
630+
return SILType::getPrimitiveObjectType(transposeFnTy);
631+
}
632+
633+
DifferentiabilityWitnessFunctionInst::DifferentiabilityWitnessFunctionInst(
634+
SILModule &module, SILDebugLocation debugLoc,
635+
DifferentiabilityWitnessFunctionKind witnessKind,
636+
SILDifferentiabilityWitness *witness, Optional<SILType> functionType)
637+
: InstructionBase(debugLoc, functionType
638+
? *functionType
639+
: getDifferentiabilityWitnessType(
640+
module, witnessKind, witness)),
641+
witnessKind(witnessKind), witness(witness),
642+
hasExplicitFunctionType(functionType) {
643+
assert(witness && "Differentiability witness must not be null");
644+
#ifndef NDEBUG
645+
if (functionType.hasValue()) {
646+
assert(module.getStage() == SILStage::Lowered &&
647+
"Explicit type is valid only in lowered SIL");
648+
}
649+
#endif
650+
}
651+
608652
FunctionRefBaseInst::FunctionRefBaseInst(SILInstructionKind Kind,
609653
SILDebugLocation DebugLoc,
610654
SILFunction *F,

0 commit comments

Comments
 (0)