Skip to content

Commit c23afc8

Browse files
authored
[AutoDiff] Require [parameters ...] for `{differentiable,linear}_fu… (#29015)
Make the `[parameters ...]` argument always required in the textual representation for `{differentiable,linear}_function` instructions, instead of defaulting to all parameter indices when unspecified.. This makes the textual representation consistent with other `[parameters ...]` arguments in SIL: - `differentiability_witness_function` - `differentiability_witness`
1 parent 743b6f5 commit c23afc8

File tree

3 files changed

+52
-83
lines changed

3 files changed

+52
-83
lines changed

docs/SIL.rst

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5606,7 +5606,7 @@ differentiable_function
56065606
::
56075607

56085608
sil-instruction ::= 'differentiable_function'
5609-
sil-differentiable-function-parameter-indices?
5609+
sil-differentiable-function-parameter-indices
56105610
sil-value ':' sil-type
56115611
sil-differentiable-function-derivative-functions-clause?
56125612
@@ -5624,8 +5624,7 @@ function. There are two derivative functions: a Jacobian-vector products (JVP)
56245624
function and a vector-Jacobian products (VJP) function.
56255625

56265626
``[parameters ...]`` specifies parameter indices that the original function is
5627-
differentiable with respect to. When not specified, it defaults to all
5628-
parameters.
5627+
differentiable with respect to.
56295628

56305629
A ``with_derivative`` clause specifies the differentiation functions associated
56315630
with the original function. When a ``with_derivative`` clause is not specified,
@@ -5641,7 +5640,7 @@ linear_function
56415640
::
56425641

56435642
sil-instruction ::= 'linear_function'
5644-
sil-linear-function-parameter-indices?
5643+
sil-linear-function-parameter-indices
56455644
sil-value ':' sil-type
56465645
sil-linear-function-transpose-function-clause?
56475646

@@ -5656,7 +5655,7 @@ Bundles a function with its transpose function into a
56565655
``@differentiable(linear)`` function.
56575656

56585657
``[parameters ...]`` specifies parameter indices that the original function is
5659-
linear with respect to. When not specified, it defaults to all parameters.
5658+
linear with respect to.
56605659

56615660
A ``with_transpose`` clause specifies the transpose function associated
56625661
with the original function. When a ``with_transpose`` clause is not specified,

lib/ParseSIL/ParseSIL.cpp

Lines changed: 40 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,32 @@ static bool parseAssignOwnershipQualifier(AssignOwnershipQualifier &Result,
20712071
return false;
20722072
}
20732073

2074+
// SWIFT_ENABLE_TENSORFLOW
2075+
// Parse a list of integer indices, prefaced with the given string label.
2076+
// Returns true on error.
2077+
static bool parseIndexList(Parser &P, StringRef label,
2078+
SmallVectorImpl<unsigned> &indices,
2079+
const Diagnostic &parseIndexDiag) {
2080+
SourceLoc loc;
2081+
// Parse `[<label> <integer_literal>...]`.
2082+
if (P.parseToken(tok::l_square, diag::sil_autodiff_expected_lsquare,
2083+
"index list") ||
2084+
P.parseSpecificIdentifier(
2085+
label, diag::sil_autodiff_expected_index_list_label, label))
2086+
return true;
2087+
while (P.Tok.is(tok::integer_literal)) {
2088+
unsigned index;
2089+
if (P.parseUnsignedInteger(index, loc, parseIndexDiag))
2090+
return true;
2091+
indices.push_back(index);
2092+
}
2093+
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
2094+
"index list"))
2095+
return true;
2096+
return false;
2097+
};
2098+
// SWIFT_ENABLE_TENSORFLOW END
2099+
20742100
// SWIFT_ENABLE_TENSORFLOW
20752101
/// sil-differentiability-witness-config-and-function ::=
20762102
/// '[' 'parameters' index-subset ']'
@@ -2083,35 +2109,14 @@ static bool parseAssignOwnershipQualifier(AssignOwnershipQualifier &Result,
20832109
static Optional<std::pair<AutoDiffConfig, SILFunction *>>
20842110
parseSILDifferentiabilityWitnessConfigAndFunction(Parser &P, SILParser &SP,
20852111
SILLocation L) {
2086-
SourceLoc lastLoc;
2087-
// Parse an index set, prefaced with the given label.
2088-
auto parseIndexSet = [&](StringRef label, SmallVectorImpl<unsigned> &indices,
2089-
const Diagnostic &parseIndexDiag) -> bool {
2090-
// Parse `[<label> <integer_literal>...]`.
2091-
if (P.parseToken(tok::l_square, diag::sil_autodiff_expected_lsquare,
2092-
"index list") ||
2093-
P.parseSpecificIdentifier(
2094-
label, diag::sil_autodiff_expected_index_list_label, label))
2095-
return true;
2096-
while (P.Tok.is(tok::integer_literal)) {
2097-
unsigned index;
2098-
if (P.parseUnsignedInteger(index, lastLoc, parseIndexDiag))
2099-
return true;
2100-
indices.push_back(index);
2101-
}
2102-
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
2103-
"index list"))
2104-
return true;
2105-
return false;
2106-
};
21072112
// Parse parameter and result indices.
21082113
SmallVector<unsigned, 8> parameterIndices;
21092114
SmallVector<unsigned, 8> resultIndices;
2110-
if (parseIndexSet("parameters", parameterIndices,
2111-
diag::sil_autodiff_expected_parameter_index))
2115+
if (parseIndexList(P, "parameters", parameterIndices,
2116+
diag::sil_autodiff_expected_parameter_index))
21122117
return {};
2113-
if (parseIndexSet("results", resultIndices,
2114-
diag::sil_autodiff_expected_result_index))
2118+
if (parseIndexList(P, "results", resultIndices,
2119+
diag::sil_autodiff_expected_result_index))
21152120
return {};
21162121
// Parse witness generic parameter clause.
21172122
GenericSignature witnessGenSig = GenericSignature();
@@ -2938,27 +2943,12 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
29382943
//
29392944
// e.g. differentiable_function [parameters 0 1 2] %0 : $T with_derivative
29402945
// {%1 : $T, %2 : $T}
2941-
// ^ jvp ^ vjp
2942-
SourceLoc lastLoc;
2946+
// ^~ jvp ^~ vjp
2947+
// Parse `[parameters <integer_literal>...]`.
29432948
SmallVector<unsigned, 8> parameterIndices;
2944-
// Parse optional `[parameters <integer_literal>...]`
2945-
if (P.Tok.is(tok::l_square) &&
2946-
P.peekToken().is(tok::identifier) &&
2947-
P.peekToken().getText() == "parameters") {
2948-
P.consumeToken(tok::l_square);
2949-
P.consumeToken(tok::identifier);
2950-
// Parse indices.
2951-
while (P.Tok.is(tok::integer_literal)) {
2952-
unsigned index;
2953-
if (P.parseUnsignedInteger(index, lastLoc,
2954-
diag::sil_autodiff_expected_parameter_index))
2955-
return true;
2956-
parameterIndices.push_back(index);
2957-
}
2958-
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
2959-
"parameter index list"))
2960-
return true;
2961-
}
2949+
if (parseIndexList(P, "parameters", parameterIndices,
2950+
diag::sil_autodiff_expected_parameter_index))
2951+
return true;
29622952
// Parse the original function value.
29632953
SILValue original;
29642954
SourceLoc originalOperandLoc;
@@ -3001,26 +2991,11 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
30012991
case SILInstructionKind::LinearFunctionInst: {
30022992
// e.g. linear_function [parameters 0 1 2] %0 : $T
30032993
// e.g. linear_function [parameters 0 1 2] %0 : $T with_transpose %1 : $T
3004-
SourceLoc lastLoc;
2994+
// Parse `[parameters <integer_literal>...]`.
30052995
SmallVector<unsigned, 8> parameterIndices;
3006-
// Parse optional `[parameters <integer_literal>...]`
3007-
if (P.Tok.is(tok::l_square) &&
3008-
P.peekToken().is(tok::identifier) &&
3009-
P.peekToken().getText() == "parameters") {
3010-
P.consumeToken(tok::l_square);
3011-
P.consumeToken(tok::identifier);
3012-
// Parse indices.
3013-
while (P.Tok.is(tok::integer_literal)) {
3014-
unsigned index;
3015-
if (P.parseUnsignedInteger(index, lastLoc,
3016-
diag::sil_autodiff_expected_parameter_index))
3017-
return true;
3018-
parameterIndices.push_back(index);
3019-
}
3020-
if (P.parseToken(tok::r_square, diag::sil_autodiff_expected_rsquare,
3021-
"parameter index list"))
3022-
return true;
3023-
}
2996+
if (parseIndexList(P, "parameters", parameterIndices,
2997+
diag::sil_autodiff_expected_parameter_index))
2998+
return true;
30242999
// Parse the original function value.
30253000
SILValue original;
30263001
SourceLoc originalOperandLoc;
@@ -3117,9 +3092,8 @@ bool SILParser::parseSILInstruction(SILBuilder &B) {
31173092
SourceLoc keyStartLoc = P.Tok.getLoc();
31183093
auto configAndFn = parseSILDifferentiabilityWitnessConfigAndFunction(
31193094
P, *this, InstLoc);
3120-
if (!configAndFn) {
3095+
if (!configAndFn)
31213096
return true;
3122-
}
31233097
auto config = configAndFn->first;
31243098
auto originalFn = configAndFn->second;
31253099
auto *witness = SILMod.lookUpDifferentiabilityWitness(

lib/SIL/SILPrinter.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,12 +1212,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
12121212

12131213
// SWIFT_ENABLE_TENSORFLOW
12141214
void visitDifferentiableFunctionInst(DifferentiableFunctionInst *dfi) {
1215-
if (!dfi->getParameterIndices()->isEmpty()) {
1216-
*this << "[parameters";
1217-
for (auto i : dfi->getParameterIndices()->getIndices())
1218-
*this << ' ' << i;
1219-
*this << "] ";
1220-
}
1215+
*this << "[parameters";
1216+
for (auto i : dfi->getParameterIndices()->getIndices())
1217+
*this << ' ' << i;
1218+
*this << "] ";
12211219
*this << getIDAndType(dfi->getOriginalFunction());
12221220
if (dfi->hasDerivativeFunctions()) {
12231221
*this << " with_derivative ";
@@ -1227,12 +1225,10 @@ class SILPrinter : public SILInstructionVisitor<SILPrinter> {
12271225
}
12281226

12291227
void visitLinearFunctionInst(LinearFunctionInst *lfi) {
1230-
if (!lfi->getParameterIndices()->isEmpty()) {
1231-
*this << "[parameters";
1232-
for (auto i : lfi->getParameterIndices()->getIndices())
1233-
*this << ' ' << i;
1234-
*this << "] ";
1235-
}
1228+
*this << "[parameters";
1229+
for (auto i : lfi->getParameterIndices()->getIndices())
1230+
*this << ' ' << i;
1231+
*this << "] ";
12361232
*this << getIDAndType(lfi->getOriginalFunction());
12371233
if (lfi->hasTransposeFunction()) {
12381234
*this << " with_transpose ";

0 commit comments

Comments
 (0)