Skip to content

Commit 75691d6

Browse files
author
ematejska
authored
[AutoDiff upstream] Add linear function SIL instructions (#30638)
Add `linear_function` and `linear_function_extract` instructions. `linear_function` creates a `@differentiable(linear)` function-typed value from an original function operand and a transpose function operand (optional). `linear_function_extract` extracts either the original or transpose function value from a `@differentiable(linear)` function. Resolves TF-1142 and TF-1143.
1 parent 2fa4fbb commit 75691d6

20 files changed

+543
-1
lines changed

docs/SIL.rst

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5811,6 +5811,37 @@ In raw SIL, the ``with_derivative`` clause is optional. In canonical SIL, the
58115811
``with_derivative`` clause is mandatory.
58125812

58135813

5814+
linear_function
5815+
```````````````
5816+
::
5817+
5818+
sil-instruction ::= 'linear_function'
5819+
sil-linear-function-parameter-indices
5820+
sil-value ':' sil-type
5821+
sil-linear-function-transpose-function-clause?
5822+
5823+
sil-linear-function-parameter-indices ::=
5824+
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
5825+
sil-linear-transpose-function-clause ::=
5826+
with_transpose sil-value ':' sil-type
5827+
5828+
linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T
5829+
5830+
Bundles a function with its transpose function into a
5831+
``@differentiable(linear)`` function.
5832+
5833+
``[parameters ...]`` specifies parameter indices that the original function is
5834+
linear with respect to.
5835+
5836+
A ``with_transpose`` clause specifies the transpose function associated
5837+
with the original function. When a ``with_transpose`` clause is not specified,
5838+
the mandatory differentiation transform will add a ``with_transpose`` clause to
5839+
the instruction.
5840+
5841+
In raw SIL, the ``with_transpose`` clause is optional. In canonical SIL,
5842+
the ``with_transpose`` clause is mandatory.
5843+
5844+
58145845
differentiable_function_extract
58155846
```````````````````````````````
58165847
::
@@ -5835,6 +5866,25 @@ Extracts the original function or a derivative function from the given
58355866
In lowered SIL, an explicit extractee type may be provided. This is currently
58365867
used by the LoadableByAddress transformation, which rewrites function types.
58375868

5869+
5870+
linear_function_extract
5871+
```````````````````````
5872+
::
5873+
5874+
sil-instruction ::= 'linear_function_extract'
5875+
'[' sil-linear-function-extractee ']'
5876+
sil-value ':' sil-type
5877+
5878+
sil-linear-function-extractee ::= 'original' | 'transpose'
5879+
5880+
linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
5881+
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T
5882+
5883+
Extracts the original function or a transpose function from the given
5884+
``@differentiable(linear)`` function. The extractee is one of the following:
5885+
``[original]`` or ``[transpose]``.
5886+
5887+
58385888
differentiability_witness_function
58395889
``````````````````````````````````
58405890
::

include/swift/AST/DiagnosticsParse.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1627,6 +1627,9 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
16271627
ERROR(sil_inst_autodiff_expected_differentiable_extractee_kind,PointsToFirstBadToken,
16281628
"expected an extractee kind attribute, which can be one of '[original]', "
16291629
"'[jvp]', and '[vjp]'", ())
1630+
ERROR(sil_inst_autodiff_expected_linear_extractee_kind,PointsToFirstBadToken,
1631+
"expected an extractee kind attribute, which can be one of '[original]' "
1632+
"and '[transpose]'", ())
16301633
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
16311634
"expected an operand of a function type", ())
16321635
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,

include/swift/SIL/SILBuilder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,6 +2172,14 @@ class SILBuilder {
21722172
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
21732173
}
21742174

2175+
LinearFunctionInst *createLinearFunction(
2176+
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
2177+
Optional<SILValue> TransposeFunction = None) {
2178+
return insert(LinearFunctionInst::create(
2179+
getModule(), getSILDebugLocation(Loc), ParameterIndices,
2180+
OriginalFunction, TransposeFunction, hasOwnership()));
2181+
}
2182+
21752183
/// Note: explicit extractee type may be specified only in lowered SIL.
21762184
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
21772185
SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee,
@@ -2181,6 +2189,13 @@ class SILBuilder {
21812189
ExtracteeType));
21822190
}
21832191

2192+
LinearFunctionExtractInst *createLinearFunctionExtract(
2193+
SILLocation Loc, LinearDifferentiableFunctionTypeComponent Extractee,
2194+
SILValue TheFunction) {
2195+
return insert(new (getModule()) LinearFunctionExtractInst(
2196+
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
2197+
}
2198+
21842199
/// Note: explicit function type may be specified only in lowered SIL.
21852200
DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction(
21862201
SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind,

include/swift/SIL/SILCloner.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,6 +2841,18 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
28412841
getOpValue(Inst->getOriginalFunction()), derivativeFns));
28422842
}
28432843

2844+
template<typename ImplClass>
2845+
void SILCloner<ImplClass>::visitLinearFunctionInst(LinearFunctionInst *Inst) {
2846+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
2847+
auto transpose = Inst->getOptionalTransposeFunction();
2848+
if (transpose)
2849+
transpose = getOpValue(*transpose);
2850+
recordClonedInstruction(
2851+
Inst, getBuilder().createLinearFunction(
2852+
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
2853+
getOpValue(Inst->getOriginalFunction()), transpose));
2854+
}
2855+
28442856
template <typename ImplClass>
28452857
void SILCloner<ImplClass>::visitDifferentiableFunctionExtractInst(
28462858
DifferentiableFunctionExtractInst *Inst) {
@@ -2854,6 +2866,16 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionExtractInst(
28542866
getOpValue(Inst->getOperand()), explicitExtracteeType));
28552867
}
28562868

2869+
template<typename ImplClass>
2870+
void SILCloner<ImplClass>::
2871+
visitLinearFunctionExtractInst(LinearFunctionExtractInst *Inst) {
2872+
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
2873+
recordClonedInstruction(
2874+
Inst, getBuilder().createLinearFunctionExtract(
2875+
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
2876+
getOpValue(Inst->getFunctionOperand())));
2877+
}
2878+
28572879
template <typename ImplClass>
28582880
void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
28592881
DifferentiabilityWitnessFunctionInst *Inst) {

include/swift/SIL/SILInstruction.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8057,6 +8057,45 @@ class DifferentiableFunctionInst final
80578057
}
80588058
};
80598059

8060+
/// LinearFunctionInst - given a function, its derivative and traspose functions,
8061+
/// create an `@differentiable(linear)` function that represents a bundle of these.
8062+
class LinearFunctionInst final :
8063+
public InstructionBaseWithTrailingOperands<
8064+
SILInstructionKind::LinearFunctionInst,
8065+
LinearFunctionInst, OwnershipForwardingSingleValueInst> {
8066+
private:
8067+
friend SILBuilder;
8068+
/// Parameters to differentiate with respect to.
8069+
IndexSubset *ParameterIndices;
8070+
/// Indicates whether a transpose function exists.
8071+
bool HasTransposeFunction;
8072+
8073+
static SILType getLinearFunctionType(
8074+
SILValue OriginalFunction, IndexSubset *ParameterIndices);
8075+
8076+
public:
8077+
LinearFunctionInst(SILDebugLocation Loc, IndexSubset *ParameterIndices,
8078+
SILValue OriginalFunction,
8079+
Optional<SILValue> TransposeFunction, bool HasOwnership);
8080+
8081+
static LinearFunctionInst *create(SILModule &Module, SILDebugLocation Loc,
8082+
IndexSubset *ParameterIndices,
8083+
SILValue OriginalFunction,
8084+
Optional<SILValue> TransposeFunction,
8085+
bool HasOwnership);
8086+
8087+
IndexSubset *getParameterIndices() const { return ParameterIndices; }
8088+
bool hasTransposeFunction() const { return HasTransposeFunction; }
8089+
SILValue getOriginalFunction() const { return getOperand(0); }
8090+
Optional<SILValue> getOptionalTransposeFunction() const {
8091+
return HasTransposeFunction ? Optional<SILValue>(getOperand(1)) : None;
8092+
}
8093+
SILValue getTransposeFunction() const {
8094+
assert(HasTransposeFunction);
8095+
return getOperand(1);
8096+
}
8097+
};
8098+
80608099
/// DifferentiableFunctionExtractInst - extracts either the original or
80618100
/// derivative function value from a `@differentiable` function.
80628101
class DifferentiableFunctionExtractInst
@@ -8094,6 +8133,39 @@ class DifferentiableFunctionExtractInst
80948133
bool hasExplicitExtracteeType() const { return HasExplicitExtracteeType; }
80958134
};
80968135

8136+
/// LinearFunctionExtractInst - given an `@differentiable(linear)` function
8137+
/// representing a bundle of the original function and the transpose function,
8138+
/// extract the specified function.
8139+
class LinearFunctionExtractInst
8140+
: public InstructionBase<
8141+
SILInstructionKind::LinearFunctionExtractInst,
8142+
SingleValueInstruction> {
8143+
private:
8144+
/// The extractee.
8145+
LinearDifferentiableFunctionTypeComponent extractee;
8146+
/// The list containing the `@differentiable(linear)` function operand.
8147+
FixedOperandList<1> operands;
8148+
8149+
static SILType
8150+
getExtracteeType(SILValue function,
8151+
LinearDifferentiableFunctionTypeComponent extractee,
8152+
SILModule &module);
8153+
8154+
public:
8155+
explicit LinearFunctionExtractInst(
8156+
SILModule &module, SILDebugLocation debugLoc,
8157+
LinearDifferentiableFunctionTypeComponent extractee,
8158+
SILValue theFunction);
8159+
8160+
LinearDifferentiableFunctionTypeComponent getExtractee() const {
8161+
return extractee;
8162+
}
8163+
8164+
SILValue getFunctionOperand() const { return operands[0].get(); }
8165+
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
8166+
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
8167+
};
8168+
80978169
/// DifferentiabilityWitnessFunctionInst - Looks up a differentiability witness
80988170
/// function for a given original function.
80998171
class DifferentiabilityWitnessFunctionInst

include/swift/SIL/SILNodes.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,14 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
694694
// Differentiable programming
695695
SINGLE_VALUE_INST(DifferentiableFunctionInst, differentiable_function,
696696
SingleValueInstruction, None, DoesNotRelease)
697+
SINGLE_VALUE_INST(LinearFunctionInst, linear_function,
698+
SingleValueInstruction, None, DoesNotRelease)
697699
SINGLE_VALUE_INST(DifferentiableFunctionExtractInst,
698700
differentiable_function_extract,
699701
SingleValueInstruction, None, DoesNotRelease)
702+
SINGLE_VALUE_INST(LinearFunctionExtractInst,
703+
linear_function_extract,
704+
SingleValueInstruction, None, DoesNotRelease)
700705
SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst,
701706
differentiability_witness_function,
702707
SingleValueInstruction, None, DoesNotRelease)

lib/IRGen/IRGenSIL.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,8 +1045,10 @@ class IRGenSILFunction :
10451045
void visitKeyPathInst(KeyPathInst *I);
10461046

10471047
void visitDifferentiableFunctionInst(DifferentiableFunctionInst *i);
1048+
void visitLinearFunctionInst(LinearFunctionInst *i);
10481049
void
10491050
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *i);
1051+
void visitLinearFunctionExtractInst(LinearFunctionExtractInst *i);
10501052
void visitDifferentiabilityWitnessFunctionInst(
10511053
DifferentiabilityWitnessFunctionInst *i);
10521054

@@ -1856,6 +1858,16 @@ void IRGenSILFunction::visitDifferentiableFunctionInst(
18561858
setLoweredExplosion(i, e);
18571859
}
18581860

1861+
void IRGenSILFunction::
1862+
visitLinearFunctionInst(LinearFunctionInst *i) {
1863+
auto origExp = getLoweredExplosion(i->getOriginalFunction());
1864+
Explosion e;
1865+
e.add(origExp.claimAll());
1866+
assert(i->hasTransposeFunction());
1867+
e.add(getLoweredExplosion(i->getTransposeFunction()).claimAll());
1868+
setLoweredExplosion(i, e);
1869+
}
1870+
18591871
void IRGenSILFunction::visitDifferentiableFunctionExtractInst(
18601872
DifferentiableFunctionExtractInst *i) {
18611873
unsigned structFieldOffset = i->getExtractee().rawValue;
@@ -1873,6 +1885,23 @@ void IRGenSILFunction::visitDifferentiableFunctionExtractInst(
18731885
setLoweredExplosion(i, e);
18741886
}
18751887

1888+
void IRGenSILFunction::
1889+
visitLinearFunctionExtractInst(LinearFunctionExtractInst *i) {
1890+
unsigned structFieldOffset = i->getExtractee().rawValue;
1891+
unsigned fieldSize = 1;
1892+
auto fnRepr = i->getFunctionOperand()->getType().getFunctionRepresentation();
1893+
if (fnRepr == SILFunctionTypeRepresentation::Thick) {
1894+
structFieldOffset *= 2;
1895+
fieldSize = 2;
1896+
}
1897+
auto diffFnExp = getLoweredExplosion(i->getFunctionOperand());
1898+
assert(diffFnExp.size() == fieldSize * 2);
1899+
Explosion e;
1900+
e.add(diffFnExp.getRange(structFieldOffset, structFieldOffset + fieldSize));
1901+
(void)diffFnExp.claimAll();
1902+
setLoweredExplosion(i, e);
1903+
}
1904+
18761905
void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst(
18771906
DifferentiabilityWitnessFunctionInst *i) {
18781907
llvm::Value *diffWitness =

lib/ParseSIL/ParseSIL.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5072,6 +5072,41 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
50725072
InstLoc, parameterIndicesSubset, original, derivativeFunctions);
50735073
break;
50745074
}
5075+
case SILInstructionKind::LinearFunctionInst: {
5076+
// e.g. linear_function [parameters 0 1 2] %0 : $T
5077+
// e.g. linear_function [parameters 0 1 2] %0 : $T with_transpose %1 : $T
5078+
// Parse `[parameters <integer_literal>...]`.
5079+
SmallVector<unsigned, 8> parameterIndices;
5080+
if (parseIndexList(P, "parameters", parameterIndices,
5081+
diag::sil_autodiff_expected_parameter_index))
5082+
return true;
5083+
// Parse the original function value.
5084+
SILValue original;
5085+
SourceLoc originalOperandLoc;
5086+
if (parseTypedValueRef(original, originalOperandLoc, B))
5087+
return true;
5088+
auto fnType = original->getType().getAs<SILFunctionType>();
5089+
if (!fnType) {
5090+
P.diagnose(originalOperandLoc,
5091+
diag::sil_inst_autodiff_expected_function_type_operand);
5092+
return true;
5093+
}
5094+
// Parse an optional transpose function.
5095+
Optional<SILValue> transpose = None;
5096+
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with_transpose") {
5097+
P.consumeToken(tok::identifier);
5098+
transpose = SILValue();
5099+
if (parseTypedValueRef(*transpose, B))
5100+
return true;
5101+
}
5102+
if (parseSILDebugLocation(InstLoc, B))
5103+
return true;
5104+
auto *parameterIndicesSubset = IndexSubset::get(
5105+
P.Context, fnType->getNumParameters(), parameterIndices);
5106+
ResultVal = B.createLinearFunction(
5107+
InstLoc, parameterIndicesSubset, original, transpose);
5108+
break;
5109+
}
50755110
case SILInstructionKind::DifferentiableFunctionExtractInst: {
50765111
// Parse the rest of the instruction: an extractee, a differentiable
50775112
// function operand, an optional explicit extractee type, and a debug
@@ -5104,6 +5139,27 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
51045139
InstLoc, extractee, functionOperand, extracteeType);
51055140
break;
51065141
}
5142+
case SILInstructionKind::LinearFunctionExtractInst: {
5143+
// Parse the rest of the instruction: an extractee, a linear function
5144+
// operand, and a debug location.
5145+
LinearDifferentiableFunctionTypeComponent extractee;
5146+
StringRef extracteeNames[2] = {"original", "transpose"};
5147+
SILValue functionOperand;
5148+
SourceLoc lastLoc;
5149+
if (P.parseToken(tok::l_square,
5150+
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
5151+
parseSILIdentifierSwitch(extractee, extracteeNames,
5152+
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
5153+
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
5154+
"extractee kind"))
5155+
return true;
5156+
if (parseTypedValueRef(functionOperand, B) ||
5157+
parseSILDebugLocation(InstLoc, B))
5158+
return true;
5159+
ResultVal = B.createLinearFunctionExtract(
5160+
InstLoc, extractee, functionOperand);
5161+
break;
5162+
}
51075163
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
51085164
// e.g. differentiability_witness_function
51095165
// [jvp] [parameters 0 1] [results 0] <T where T: Differentiable>

lib/SIL/OperandOwnership.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ FORWARD_ANY_OWNERSHIP_INST(DestructureStruct)
349349
FORWARD_ANY_OWNERSHIP_INST(DestructureTuple)
350350
FORWARD_ANY_OWNERSHIP_INST(InitExistentialRef)
351351
FORWARD_ANY_OWNERSHIP_INST(DifferentiableFunction)
352+
FORWARD_ANY_OWNERSHIP_INST(LinearFunction)
352353
#undef FORWARD_ANY_OWNERSHIP_INST
353354

354355
// An instruction that forwards a constant ownership or trivial ownership.
@@ -369,6 +370,8 @@ FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive, TupleExtract)
369370
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive, StructExtract)
370371
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive,
371372
DifferentiableFunctionExtract)
373+
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive,
374+
LinearFunctionExtract)
372375
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Owned, MustBeInvalidated,
373376
MarkUninitialized)
374377
#undef CONSTANT_OR_NONE_OWNERSHIP_INST

0 commit comments

Comments
 (0)