Skip to content

[AutoDiff upstream] [SIL] Add the 'linear_function' and 'linear_function_extract' instructions #30638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions docs/SIL.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5811,6 +5811,37 @@ In raw SIL, the ``with_derivative`` clause is optional. In canonical SIL, the
``with_derivative`` clause is mandatory.


linear_function
```````````````
::

sil-instruction ::= 'linear_function'
sil-linear-function-parameter-indices
sil-value ':' sil-type
sil-linear-function-transpose-function-clause?

sil-linear-function-parameter-indices ::=
'[' 'parameters' [0-9]+ (' ' [0-9]+)* ']'
sil-linear-transpose-function-clause ::=
with_transpose sil-value ':' sil-type

linear_function [parameters 0] %0 : $(T) -> T with_transpose %1 : $(T) -> T

Bundles a function with its transpose function into a
``@differentiable(linear)`` function.

``[parameters ...]`` specifies parameter indices that the original function is
linear with respect to.

A ``with_transpose`` clause specifies the transpose function associated
with the original function. When a ``with_transpose`` clause is not specified,
the mandatory differentiation transform will add a ``with_transpose`` clause to
the instruction.

In raw SIL, the ``with_transpose`` clause is optional. In canonical SIL,
the ``with_transpose`` clause is mandatory.


differentiable_function_extract
```````````````````````````````
::
Expand All @@ -5835,6 +5866,25 @@ Extracts the original function or a derivative function from the given
In lowered SIL, an explicit extractee type may be provided. This is currently
used by the LoadableByAddress transformation, which rewrites function types.


linear_function_extract
```````````````````````
::

sil-instruction ::= 'linear_function_extract'
'[' sil-linear-function-extractee ']'
sil-value ':' sil-type

sil-linear-function-extractee ::= 'original' | 'transpose'

linear_function_extract [original] %0 : $@differentiable(linear) (T) -> T
linear_function_extract [transpose] %0 : $@differentiable(linear) (T) -> T

Extracts the original function or a transpose function from the given
``@differentiable(linear)`` function. The extractee is one of the following:
``[original]`` or ``[transpose]``.


differentiability_witness_function
``````````````````````````````````
::
Expand Down
3 changes: 3 additions & 0 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,9 @@ ERROR(sil_inst_autodiff_operand_list_expected_rbrace,PointsToFirstBadToken,
ERROR(sil_inst_autodiff_expected_differentiable_extractee_kind,PointsToFirstBadToken,
"expected an extractee kind attribute, which can be one of '[original]', "
"'[jvp]', and '[vjp]'", ())
ERROR(sil_inst_autodiff_expected_linear_extractee_kind,PointsToFirstBadToken,
"expected an extractee kind attribute, which can be one of '[original]' "
"and '[transpose]'", ())
ERROR(sil_inst_autodiff_expected_function_type_operand,PointsToFirstBadToken,
"expected an operand of a function type", ())
ERROR(sil_inst_autodiff_expected_differentiability_witness_kind,PointsToFirstBadToken,
Expand Down
15 changes: 15 additions & 0 deletions include/swift/SIL/SILBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2172,6 +2172,14 @@ class SILBuilder {
OriginalFunction, JVPAndVJPFunctions, hasOwnership()));
}

LinearFunctionInst *createLinearFunction(
SILLocation Loc, IndexSubset *ParameterIndices, SILValue OriginalFunction,
Optional<SILValue> TransposeFunction = None) {
return insert(LinearFunctionInst::create(
getModule(), getSILDebugLocation(Loc), ParameterIndices,
OriginalFunction, TransposeFunction, hasOwnership()));
}

/// Note: explicit extractee type may be specified only in lowered SIL.
DifferentiableFunctionExtractInst *createDifferentiableFunctionExtract(
SILLocation Loc, NormalDifferentiableFunctionTypeComponent Extractee,
Expand All @@ -2181,6 +2189,13 @@ class SILBuilder {
ExtracteeType));
}

LinearFunctionExtractInst *createLinearFunctionExtract(
SILLocation Loc, LinearDifferentiableFunctionTypeComponent Extractee,
SILValue TheFunction) {
return insert(new (getModule()) LinearFunctionExtractInst(
getModule(), getSILDebugLocation(Loc), Extractee, TheFunction));
}

/// Note: explicit function type may be specified only in lowered SIL.
DifferentiabilityWitnessFunctionInst *createDifferentiabilityWitnessFunction(
SILLocation Loc, DifferentiabilityWitnessFunctionKind WitnessKind,
Expand Down
22 changes: 22 additions & 0 deletions include/swift/SIL/SILCloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -2841,6 +2841,18 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionInst(
getOpValue(Inst->getOriginalFunction()), derivativeFns));
}

template<typename ImplClass>
void SILCloner<ImplClass>::visitLinearFunctionInst(LinearFunctionInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
auto transpose = Inst->getOptionalTransposeFunction();
if (transpose)
transpose = getOpValue(*transpose);
recordClonedInstruction(
Inst, getBuilder().createLinearFunction(
getOpLocation(Inst->getLoc()), Inst->getParameterIndices(),
getOpValue(Inst->getOriginalFunction()), transpose));
}

template <typename ImplClass>
void SILCloner<ImplClass>::visitDifferentiableFunctionExtractInst(
DifferentiableFunctionExtractInst *Inst) {
Expand All @@ -2854,6 +2866,16 @@ void SILCloner<ImplClass>::visitDifferentiableFunctionExtractInst(
getOpValue(Inst->getOperand()), explicitExtracteeType));
}

template<typename ImplClass>
void SILCloner<ImplClass>::
visitLinearFunctionExtractInst(LinearFunctionExtractInst *Inst) {
getBuilder().setCurrentDebugScope(getOpScope(Inst->getDebugScope()));
recordClonedInstruction(
Inst, getBuilder().createLinearFunctionExtract(
getOpLocation(Inst->getLoc()), Inst->getExtractee(),
getOpValue(Inst->getFunctionOperand())));
}

template <typename ImplClass>
void SILCloner<ImplClass>::visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *Inst) {
Expand Down
72 changes: 72 additions & 0 deletions include/swift/SIL/SILInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -8057,6 +8057,45 @@ class DifferentiableFunctionInst final
}
};

/// LinearFunctionInst - given a function, its derivative and traspose functions,
/// create an `@differentiable(linear)` function that represents a bundle of these.
class LinearFunctionInst final :
public InstructionBaseWithTrailingOperands<
SILInstructionKind::LinearFunctionInst,
LinearFunctionInst, OwnershipForwardingSingleValueInst> {
private:
friend SILBuilder;
/// Parameters to differentiate with respect to.
IndexSubset *ParameterIndices;
/// Indicates whether a transpose function exists.
bool HasTransposeFunction;

static SILType getLinearFunctionType(
SILValue OriginalFunction, IndexSubset *ParameterIndices);

public:
LinearFunctionInst(SILDebugLocation Loc, IndexSubset *ParameterIndices,
SILValue OriginalFunction,
Optional<SILValue> TransposeFunction, bool HasOwnership);

static LinearFunctionInst *create(SILModule &Module, SILDebugLocation Loc,
IndexSubset *ParameterIndices,
SILValue OriginalFunction,
Optional<SILValue> TransposeFunction,
bool HasOwnership);

IndexSubset *getParameterIndices() const { return ParameterIndices; }
bool hasTransposeFunction() const { return HasTransposeFunction; }
SILValue getOriginalFunction() const { return getOperand(0); }
Optional<SILValue> getOptionalTransposeFunction() const {
return HasTransposeFunction ? Optional<SILValue>(getOperand(1)) : None;
}
SILValue getTransposeFunction() const {
assert(HasTransposeFunction);
return getOperand(1);
}
};

/// DifferentiableFunctionExtractInst - extracts either the original or
/// derivative function value from a `@differentiable` function.
class DifferentiableFunctionExtractInst
Expand Down Expand Up @@ -8094,6 +8133,39 @@ class DifferentiableFunctionExtractInst
bool hasExplicitExtracteeType() const { return HasExplicitExtracteeType; }
};

/// LinearFunctionExtractInst - given an `@differentiable(linear)` function
/// representing a bundle of the original function and the transpose function,
/// extract the specified function.
class LinearFunctionExtractInst
: public InstructionBase<
SILInstructionKind::LinearFunctionExtractInst,
SingleValueInstruction> {
private:
/// The extractee.
LinearDifferentiableFunctionTypeComponent extractee;
/// The list containing the `@differentiable(linear)` function operand.
FixedOperandList<1> operands;

static SILType
getExtracteeType(SILValue function,
LinearDifferentiableFunctionTypeComponent extractee,
SILModule &module);

public:
explicit LinearFunctionExtractInst(
SILModule &module, SILDebugLocation debugLoc,
LinearDifferentiableFunctionTypeComponent extractee,
SILValue theFunction);

LinearDifferentiableFunctionTypeComponent getExtractee() const {
return extractee;
}

SILValue getFunctionOperand() const { return operands[0].get(); }
ArrayRef<Operand> getAllOperands() const { return operands.asArray(); }
MutableArrayRef<Operand> getAllOperands() { return operands.asArray(); }
};

/// DifferentiabilityWitnessFunctionInst - Looks up a differentiability witness
/// function for a given original function.
class DifferentiabilityWitnessFunctionInst
Expand Down
5 changes: 5 additions & 0 deletions include/swift/SIL/SILNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -694,9 +694,14 @@ ABSTRACT_VALUE_AND_INST(SingleValueInstruction, ValueBase, SILInstruction)
// Differentiable programming
SINGLE_VALUE_INST(DifferentiableFunctionInst, differentiable_function,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(LinearFunctionInst, linear_function,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(DifferentiableFunctionExtractInst,
differentiable_function_extract,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(LinearFunctionExtractInst,
linear_function_extract,
SingleValueInstruction, None, DoesNotRelease)
SINGLE_VALUE_INST(DifferentiabilityWitnessFunctionInst,
differentiability_witness_function,
SingleValueInstruction, None, DoesNotRelease)
Expand Down
29 changes: 29 additions & 0 deletions lib/IRGen/IRGenSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1045,8 +1045,10 @@ class IRGenSILFunction :
void visitKeyPathInst(KeyPathInst *I);

void visitDifferentiableFunctionInst(DifferentiableFunctionInst *i);
void visitLinearFunctionInst(LinearFunctionInst *i);
void
visitDifferentiableFunctionExtractInst(DifferentiableFunctionExtractInst *i);
void visitLinearFunctionExtractInst(LinearFunctionExtractInst *i);
void visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *i);

Expand Down Expand Up @@ -1856,6 +1858,16 @@ void IRGenSILFunction::visitDifferentiableFunctionInst(
setLoweredExplosion(i, e);
}

void IRGenSILFunction::
visitLinearFunctionInst(LinearFunctionInst *i) {
auto origExp = getLoweredExplosion(i->getOriginalFunction());
Explosion e;
e.add(origExp.claimAll());
assert(i->hasTransposeFunction());
e.add(getLoweredExplosion(i->getTransposeFunction()).claimAll());
setLoweredExplosion(i, e);
}

void IRGenSILFunction::visitDifferentiableFunctionExtractInst(
DifferentiableFunctionExtractInst *i) {
unsigned structFieldOffset = i->getExtractee().rawValue;
Expand All @@ -1873,6 +1885,23 @@ void IRGenSILFunction::visitDifferentiableFunctionExtractInst(
setLoweredExplosion(i, e);
}

void IRGenSILFunction::
visitLinearFunctionExtractInst(LinearFunctionExtractInst *i) {
unsigned structFieldOffset = i->getExtractee().rawValue;
unsigned fieldSize = 1;
auto fnRepr = i->getFunctionOperand()->getType().getFunctionRepresentation();
if (fnRepr == SILFunctionTypeRepresentation::Thick) {
structFieldOffset *= 2;
fieldSize = 2;
}
auto diffFnExp = getLoweredExplosion(i->getFunctionOperand());
assert(diffFnExp.size() == fieldSize * 2);
Explosion e;
e.add(diffFnExp.getRange(structFieldOffset, structFieldOffset + fieldSize));
(void)diffFnExp.claimAll();
setLoweredExplosion(i, e);
}

void IRGenSILFunction::visitDifferentiabilityWitnessFunctionInst(
DifferentiabilityWitnessFunctionInst *i) {
llvm::Value *diffWitness =
Expand Down
56 changes: 56 additions & 0 deletions lib/ParseSIL/ParseSIL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5072,6 +5072,41 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
InstLoc, parameterIndicesSubset, original, derivativeFunctions);
break;
}
case SILInstructionKind::LinearFunctionInst: {
// e.g. linear_function [parameters 0 1 2] %0 : $T
// e.g. linear_function [parameters 0 1 2] %0 : $T with_transpose %1 : $T
// Parse `[parameters <integer_literal>...]`.
SmallVector<unsigned, 8> parameterIndices;
if (parseIndexList(P, "parameters", parameterIndices,
diag::sil_autodiff_expected_parameter_index))
return true;
// Parse the original function value.
SILValue original;
SourceLoc originalOperandLoc;
if (parseTypedValueRef(original, originalOperandLoc, B))
return true;
auto fnType = original->getType().getAs<SILFunctionType>();
if (!fnType) {
P.diagnose(originalOperandLoc,
diag::sil_inst_autodiff_expected_function_type_operand);
return true;
}
// Parse an optional transpose function.
Optional<SILValue> transpose = None;
if (P.Tok.is(tok::identifier) && P.Tok.getText() == "with_transpose") {
P.consumeToken(tok::identifier);
transpose = SILValue();
if (parseTypedValueRef(*transpose, B))
return true;
}
if (parseSILDebugLocation(InstLoc, B))
return true;
auto *parameterIndicesSubset = IndexSubset::get(
P.Context, fnType->getNumParameters(), parameterIndices);
ResultVal = B.createLinearFunction(
InstLoc, parameterIndicesSubset, original, transpose);
break;
}
case SILInstructionKind::DifferentiableFunctionExtractInst: {
// Parse the rest of the instruction: an extractee, a differentiable
// function operand, an optional explicit extractee type, and a debug
Expand Down Expand Up @@ -5104,6 +5139,27 @@ bool SILParser::parseSpecificSILInstruction(SILBuilder &B,
InstLoc, extractee, functionOperand, extracteeType);
break;
}
case SILInstructionKind::LinearFunctionExtractInst: {
// Parse the rest of the instruction: an extractee, a linear function
// operand, and a debug location.
LinearDifferentiableFunctionTypeComponent extractee;
StringRef extracteeNames[2] = {"original", "transpose"};
SILValue functionOperand;
SourceLoc lastLoc;
if (P.parseToken(tok::l_square,
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
parseSILIdentifierSwitch(extractee, extracteeNames,
diag::sil_inst_autodiff_expected_linear_extractee_kind) ||
P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
"extractee kind"))
return true;
if (parseTypedValueRef(functionOperand, B) ||
parseSILDebugLocation(InstLoc, B))
return true;
ResultVal = B.createLinearFunctionExtract(
InstLoc, extractee, functionOperand);
break;
}
case SILInstructionKind::DifferentiabilityWitnessFunctionInst: {
// e.g. differentiability_witness_function
// [jvp] [parameters 0 1] [results 0] <T where T: Differentiable>
Expand Down
3 changes: 3 additions & 0 deletions lib/SIL/OperandOwnership.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ FORWARD_ANY_OWNERSHIP_INST(DestructureStruct)
FORWARD_ANY_OWNERSHIP_INST(DestructureTuple)
FORWARD_ANY_OWNERSHIP_INST(InitExistentialRef)
FORWARD_ANY_OWNERSHIP_INST(DifferentiableFunction)
FORWARD_ANY_OWNERSHIP_INST(LinearFunction)
#undef FORWARD_ANY_OWNERSHIP_INST

// An instruction that forwards a constant ownership or trivial ownership.
Expand All @@ -369,6 +370,8 @@ FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive, TupleExtract)
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive, StructExtract)
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive,
DifferentiableFunctionExtract)
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Guaranteed, MustBeLive,
LinearFunctionExtract)
FORWARD_CONSTANT_OR_NONE_OWNERSHIP_INST(Owned, MustBeInvalidated,
MarkUninitialized)
#undef CONSTANT_OR_NONE_OWNERSHIP_INST
Expand Down
Loading