Skip to content

Commit aea64d3

Browse files
committed
Update differentiability witness syntax.
Print original function name in comment. ``` // differentiability witness for foo sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo ```
1 parent 873468f commit aea64d3

File tree

4 files changed

+56
-39
lines changed

4 files changed

+56
-39
lines changed

lib/ParseSIL/ParseSIL.cpp

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6805,17 +6805,17 @@ static void convertRequirements(Parser &P, SILFunction *F,
68056805

68066806
/// decl-sil-differentiability-witness ::=
68076807
/// 'sil_differentiability_witness'
6808+
/// '[' 'parameters' index-subset ']'
6809+
/// '[' 'results' index-subset ']'
6810+
/// ('[' 'where' derivatve-generic-signature-requirements ']')?
68086811
/// sil-function-name ':' sil-type
6809-
/// 'parameters' autodiff-index-subset
6810-
/// 'results' autodiff-index-subset
6811-
/// ('where' generic-signature)?
68126812
/// '{'
68136813
/// ('jvp' sil-function-name ':' sil-type)?
68146814
/// ('vjp' sil-function-name ':' sil-type)?
68156815
/// '}'
68166816
///
6817-
/// autodiff-index-subset ::=
6818-
/// '(' [0-9]+ (',', [0-9]+)* ')'
6817+
/// index-subset ::=
6818+
/// [0-9]+ (' ' [0-9]+)*
68196819
bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68206820
P.consumeToken(tok::kw_sil_differentiability_witness);
68216821
SILParser State(P);
@@ -6853,21 +6853,17 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68536853
State.TUState.PotentialZombieFns.insert(fn);
68546854
return false;
68556855
};
6856-
// Parse original function name.
6857-
SILFunction *originalFn;
6858-
if (parseFunctionName(originalFn))
6859-
return true;
68606856

68616857
SourceLoc lastLoc = P.getEndOfPreviousLoc();
68626858
// Parse an index subset, prefaced with the given label.
68636859
auto parseIndexSubset =
68646860
[&](StringRef label, IndexSubset *& indexSubset) -> bool {
6861+
if (P.parseToken(tok::l_square, diag::sil_diff_witness_expected_keyword,
6862+
"["))
6863+
return true;
68656864
if (P.parseSpecificIdentifier(
68666865
label, diag::sil_diff_witness_expected_keyword, label))
68676866
return true;
6868-
if (P.parseToken(tok::l_paren, diag::sil_diff_witness_expected_keyword,
6869-
"("))
6870-
return true;
68716867
// Parse parameter index list.
68726868
SmallVector<unsigned, 8> paramIndices;
68736869
// Function that parses an index into `paramIndices`. Returns true on error.
@@ -6884,11 +6880,11 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
68846880
if (parseParam())
68856881
return true;
68866882
// Parse rest.
6887-
while (P.consumeIf(tok::comma))
6883+
while (P.Tok.isNot(tok::r_square))
68886884
if (parseParam())
68896885
return true;
6890-
if (P.parseToken(tok::r_paren, diag::sil_diff_witness_expected_keyword,
6891-
"("))
6886+
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword,
6887+
"]"))
68926888
return true;
68936889
auto maxIndexRef =
68946890
std::max_element(paramIndices.begin(), paramIndices.end());
@@ -6907,18 +6903,33 @@ bool SILParserTUState::parseSILDifferentiabilityWitness(Parser &P) {
69076903
// Parse a trailing 'where' clause (optional).
69086904
// This represents derivative generic signature requirements.
69096905
GenericSignature *derivativeGenSig = nullptr;
6910-
if (P.Tok.is(tok::kw_where)) {
6911-
SourceLoc whereLoc;
6912-
SmallVector<RequirementRepr, 4> requirementReprs;
6906+
SourceLoc whereLoc;
6907+
SmallVector<RequirementRepr, 4> derivativeRequirementReprs;
6908+
if (P.Tok.is(tok::l_square) && P.peekToken().is(tok::kw_where)) {
6909+
P.consumeToken(tok::l_square);
69136910
bool firstTypeInComplete;
6914-
P.parseGenericWhereClause(whereLoc, requirementReprs, firstTypeInComplete,
6911+
P.parseGenericWhereClause(whereLoc, derivativeRequirementReprs,
6912+
firstTypeInComplete,
69156913
/*AllowLayoutConstraints*/ false);
6916-
auto *whereClause = TrailingWhereClause::create(
6917-
originalFn->getModule().getASTContext(), whereLoc, requirementReprs);
6914+
if (P.parseToken(tok::r_square, diag::sil_diff_witness_expected_keyword,
6915+
"]"))
6916+
return true;
6917+
}
6918+
6919+
// Parse original function name.
6920+
SILFunction *originalFn;
6921+
if (parseFunctionName(originalFn))
6922+
return true;
6923+
6924+
// Resolve derivative requirements.
6925+
if (!derivativeRequirementReprs.empty()) {
69186926
SmallVector<Requirement, 4> requirements;
6927+
auto *whereClause = TrailingWhereClause::create(
6928+
originalFn->getModule().getASTContext(), whereLoc,
6929+
derivativeRequirementReprs);
69196930
convertRequirements(P, originalFn, whereClause->getRequirements(),
69206931
requirements);
6921-
assert(requirements.size() == requirementReprs.size());
6932+
assert(requirements.size() == derivativeRequirementReprs.size());
69226933
derivativeGenSig = evaluateOrDefault(
69236934
P.Context.evaluator,
69246935
AbstractGenericSignatureRequest{

lib/SIL/SILPrinter.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3060,24 +3060,24 @@ void SILDefaultWitnessTable::dump() const {
30603060
// SWIFT_ENABLE_TENSORFLOW
30613061
void SILDifferentiabilityWitness::print(
30623062
llvm::raw_ostream &OS, bool verbose) const {
3063+
OS << "// differentiability witness for "
3064+
<< demangleSymbol(originalFunction->getName()) << "\n";
30633065
// sil_differentiability_witness @original-function-name : $original-sil-type
30643066
PrintOptions qualifiedSILTypeOptions = PrintOptions::printQualifiedSILType();
30653067
OS << "sil_differentiability_witness ";
30663068
printLinkage(OS, linkage, ForDefinition);
3067-
OS << "@" << originalFunction->getName() << " : "
3068-
<< originalFunction->getLoweredType();
3069-
// parameters (0, 1, ...)
3070-
OS << " parameters (";
3069+
// [parameters 0 1 ...]
3070+
OS << "[parameters ";
30713071
interleave(parameterIndices->getIndices(),
30723072
[&](unsigned index) { OS << index; },
3073-
[&] { OS << ", "; });
3074-
// results (0, 1, ...)
3075-
OS << ") results (";
3073+
[&] { OS << " "; });
3074+
// [results 0 1 ...]
3075+
OS << "] [results ";
30763076
interleave(resultIndices->getIndices(),
30773077
[&](unsigned index) { OS << index; },
3078-
[&] { OS << ", "; });
3079-
OS << ')';
3080-
// wrt 0, 1, ...
3078+
[&] { OS << " "; });
3079+
OS << ']';
3080+
// [where ...]
30813081
if (derivativeGenericSignature) {
30823082
// NOTE: This needs to be changed if there is no utility for parsing
30833083
// generic signatures. Idea: we could instead print the type of the original
@@ -3096,16 +3096,20 @@ void SILDifferentiabilityWitness::print(
30963096
}
30973097
}
30983098
if (!requirements.empty()) {
3099-
OS << " where ";
3099+
OS << " [where ";
31003100
auto SubPrinter = PrintOptions::printSIL();
31013101
interleave(requirements,
31023102
[&](Requirement req) {
31033103
req.print(OS, SubPrinter);
31043104
return;
31053105
},
31063106
[&] { OS << ", "; });
3107+
OS << ']';
31073108
}
31083109
}
3110+
// original: @original-function-name : $original-sil-type
3111+
OS << " @" << originalFunction->getName() << " : "
3112+
<< originalFunction->getLoweredType();
31093113
// {
31103114
// jvp: @jvp-function-name : $jvp-sil-type
31113115
// vjp: @vjp-function-name : $vjp-sil-type

lib/SIL/SILVerifier.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5358,14 +5358,16 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
53585358
return;
53595359
#endif
53605360
auto origFnType = originalFunction->getLoweredFunctionType();
5361+
CanGenericSignature derivativeCanGenSig;
5362+
if (auto *derivativeGenSig = getDerivativeGenericSignature())
5363+
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
53615364
if (jvp) {
53625365
// TODO(TF-893): Change `SILFunctionType::getAutoDiffDerivativeFunctionType`
53635366
// to accept result indices.
53645367
auto expectedJVPType = origFnType->getAutoDiffDerivativeFunctionType(
53655368
getParameterIndices(), /*resultIndex*/ *resultIndices->begin(),
53665369
AutoDiffDerivativeFunctionKind::JVP, M.Types,
5367-
LookUpConformanceInModule(M.getSwiftModule()),
5368-
getDerivativeGenericSignature()->getCanonicalSignature());
5370+
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig);
53695371
SILVerifier(*jvp).requireSameType(
53705372
SILType::getPrimitiveObjectType(jvp->getLoweredFunctionType()),
53715373
SILType::getPrimitiveObjectType(expectedJVPType),
@@ -5377,8 +5379,7 @@ void SILDifferentiabilityWitness::verify(const SILModule &M) const {
53775379
auto expectedVJPType = origFnType->getAutoDiffDerivativeFunctionType(
53785380
getParameterIndices(), /*resultIndex*/ *resultIndices->begin(),
53795381
AutoDiffDerivativeFunctionKind::VJP, M.Types,
5380-
LookUpConformanceInModule(M.getSwiftModule()),
5381-
getDerivativeGenericSignature()->getCanonicalSignature());
5382+
LookUpConformanceInModule(M.getSwiftModule()), derivativeCanGenSig);
53825383
SILVerifier(*vjp).requireSameType(
53835384
SILType::getPrimitiveObjectType(vjp->getLoweredFunctionType()),
53845385
SILType::getPrimitiveObjectType(expectedVJPType),

test/AutoDiff/sil_differentiability_witness_parse.sil

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,13 @@ bb0(%0 : $*T.TangentVector, %1 : $*T.TangentVector):
133133
// static AdditiveArithmetic<>.zero.getter
134134
sil [serialized] [always_inline] @$ss18AdditiveArithmeticPss27ExpressibleByIntegerLiteralRzrlE4zeroxvgZ : $@convention(method) <τ_0_0 where τ_0_0 : AdditiveArithmetic, τ_0_0 : ExpressibleByIntegerLiteral> (@thick τ_0_0.Type) -> @out τ_0_0
135135

136-
sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable {
136+
sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
137137
jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
138138
vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
139139
}
140140

141-
// CHECK-LABEL: sil_differentiability_witness hidden @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 parameters (0, 1) results (0) where τ_0_0 : _Differentiable {
141+
// CHECK-LABEL: // differentiability witness for foo
142+
// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @foo : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 {
142143
// CHECK: jvp: @AD__foo__jvp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector)
143144
// CHECK: vjp: @AD__foo__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float))
144145
// CHECK: }

0 commit comments

Comments
 (0)