@@ -6745,21 +6745,225 @@ bool SILParserTUState::parseSILDefaultWitnessTable(Parser &P) {
6745
6745
return false ;
6746
6746
}
6747
6747
6748
+ // SWIFT_ENABLE_TENSORFLOW
6749
+ // TODO: Dedupe with `SILParser::convertRequirements` upstream.
6750
+ // Consider defining this as `Parser::convertRequirements`.
6751
+ static void convertRequirements (Parser &P, SILFunction *F,
6752
+ ArrayRef<RequirementRepr> From,
6753
+ SmallVectorImpl<Requirement> &To) {
6754
+ if (From.empty ()) {
6755
+ To.clear ();
6756
+ return ;
6757
+ }
6758
+
6759
+ auto *GenericEnv = F->getLoweredFunctionType ()
6760
+ ->getGenericSignature ()
6761
+ ->getGenericEnvironment ();
6762
+ assert (GenericEnv);
6763
+ (void )GenericEnv;
6764
+
6765
+ IdentTypeReprLookup PerformLookup (P);
6766
+ // Use parser lexical scopes to resolve references
6767
+ // to the generic parameters.
6768
+ auto ResolveToInterfaceType = [&](TypeLoc Ty) -> Type {
6769
+ Ty.getTypeRepr ()->walk (PerformLookup);
6770
+ swift::performTypeLocChecking (P.Context , Ty, /* isSILMode*/ true ,
6771
+ /* isSILType*/ true , GenericEnv, &P.SF );
6772
+ assert (Ty.getType ());
6773
+ return Ty.getType ()->mapTypeOutOfContext ();
6774
+ };
6775
+
6776
+ for (auto &Req : From) {
6777
+ if (Req.getKind () == RequirementReprKind::SameType) {
6778
+ auto FirstType = ResolveToInterfaceType (Req.getFirstTypeLoc ());
6779
+ auto SecondType = ResolveToInterfaceType (Req.getSecondTypeLoc ());
6780
+ Requirement ConvertedRequirement (RequirementKind::SameType, FirstType,
6781
+ SecondType);
6782
+ To.push_back (ConvertedRequirement);
6783
+ continue ;
6784
+ }
6785
+
6786
+ if (Req.getKind () == RequirementReprKind::TypeConstraint) {
6787
+ auto Subject = ResolveToInterfaceType (Req.getSubjectLoc ());
6788
+ auto Constraint = ResolveToInterfaceType (Req.getConstraintLoc ());
6789
+ Requirement ConvertedRequirement (RequirementKind::Conformance, Subject,
6790
+ Constraint);
6791
+ To.push_back (ConvertedRequirement);
6792
+ continue ;
6793
+ }
6794
+
6795
+ if (Req.getKind () == RequirementReprKind::LayoutConstraint) {
6796
+ auto Subject = ResolveToInterfaceType (Req.getSubjectLoc ());
6797
+ Requirement ConvertedRequirement (RequirementKind::Layout, Subject,
6798
+ Req.getLayoutConstraint ());
6799
+ To.push_back (ConvertedRequirement);
6800
+ continue ;
6801
+ }
6802
+ llvm_unreachable (" Unsupported requirement kind" );
6803
+ }
6804
+ }
6805
+
6748
6806
// / decl-sil-differentiability-witness ::=
6749
6807
// / 'sil_differentiability_witness'
6750
- // / sil-function-name
6751
- // / 'wrt' autodiff-index-subset
6752
- // / 'sources' autodiff-index-subset
6753
- // / ('derivative_generic_signature' generic-signature)?
6754
- // / '{' ('jvp' sil-function-name)? ('vjp' sil-function-name)? '}'
6808
+ // / sil-function-name ':' sil-type
6809
+ // / 'parameters' autodiff-index-subset
6810
+ // / 'results' autodiff-index-subset
6811
+ // / ('where' generic-signature)?
6812
+ // / '{'
6813
+ // / ('jvp' sil-function-name ':' sil-type)?
6814
+ // / ('vjp' sil-function-name ':' sil-type)?
6815
+ // / '}'
6755
6816
// /
6756
6817
// / autodiff-index-subset ::=
6757
- // / [0-9]+ (',', [0-9]+)*
6818
+ // / '(' [0-9]+ (',', [0-9]+)* ')'
6758
6819
bool SILParserTUState::parseSILDifferentiabilityWitness (Parser &P) {
6759
6820
P.consumeToken (tok::kw_sil_differentiability_witness);
6760
- // TODO(TF-867): Implement parsing. Test round-tripping with printing.
6821
+ SILParser State (P);
6822
+
6823
+ // Parse the linkage.
6824
+ Optional<SILLinkage> linkage;
6825
+ if (parseSILLinkage (linkage, P))
6826
+ return true ;
6827
+ if (!linkage)
6828
+ linkage = SILLinkage::PublicExternal;
6829
+
6830
+ Scope S (&P, ScopeKind::TopLevel);
6831
+ Scope Body (&P, ScopeKind::FunctionBody);
6832
+
6833
+ auto parseFunctionNameAndType = [&](SILFunction *&fn) -> bool {
6834
+ Identifier name;
6835
+ SILType ty;
6836
+ SourceLoc fnNameLoc = P.Tok .getLoc ();
6837
+ // We need to turn on InSILBody to parse the function reference.
6838
+ Lexer::SILBodyRAII tmp (*P.L );
6839
+ GenericEnvironment *ignoredEnv;
6840
+ if ((State.parseGlobalName (name)) ||
6841
+ P.parseToken (tok::colon, diag::expected_sil_colon_value_ref) ||
6842
+ State.parseSILType (ty, ignoredEnv, /* IsFuncDecl*/ true ))
6843
+ return true ;
6844
+
6845
+ // The function doesn't exist yet. Create a zombie forward declaration.
6846
+ auto fnType = ty.getAs <SILFunctionType>();
6847
+ if (!fnType || !ty.isObject ()) {
6848
+ P.diagnose (fnNameLoc, diag::expected_sil_function_type);
6849
+ return true ;
6850
+ }
6851
+ fn = State.getGlobalNameForReference (name, fnType, fnNameLoc, true );
6852
+ State.TUState .PotentialZombieFns .insert (fn);
6853
+ return false ;
6854
+ };
6855
+
6856
+ SourceLoc lastLoc = P.getEndOfPreviousLoc ();
6857
+
6858
+ SILFunction *originalFn;
6859
+ if (parseFunctionNameAndType (originalFn))
6860
+ return true ;
6861
+
6862
+ auto parseAutoDiffIndexSubset =
6863
+ [&](StringRef label, AutoDiffIndexSubset *& paramIndexSubset) -> bool {
6864
+ if (P.parseSpecificIdentifier (
6865
+ label, diag::sil_diff_witness_expected_keyword, label))
6866
+ return true ;
6867
+ if (P.parseToken (tok::l_paren, diag::sil_diff_witness_expected_keyword,
6868
+ " (" ))
6869
+ return true ;
6870
+ // Parse parameter index list.
6871
+ SmallVector<unsigned , 8 > paramIndices;
6872
+ // Function that parses an index into `paramIndices`. Returns true on error.
6873
+ auto parseParam = [&]() -> bool {
6874
+ unsigned index;
6875
+ // TODO: Reject non-ascending parameter index lists.
6876
+ if (P.parseUnsignedInteger (index, lastLoc,
6877
+ diag::sil_diff_witness_expected_parameter_list))
6878
+ return true ;
6879
+ paramIndices.push_back (index);
6880
+ return false ;
6881
+ };
6882
+ // Parse first.
6883
+ if (parseParam ())
6884
+ return true ;
6885
+ // Parse rest.
6886
+ while (P.consumeIf (tok::comma))
6887
+ if (parseParam ())
6888
+ return true ;
6889
+ if (P.parseToken (tok::r_paren, diag::sil_diff_witness_expected_keyword,
6890
+ " (" ))
6891
+ return true ;
6892
+ auto maxIndexRef =
6893
+ std::max_element (paramIndices.begin (), paramIndices.end ());
6894
+ paramIndexSubset = AutoDiffIndexSubset::get (
6895
+ P.Context , maxIndexRef ? *maxIndexRef + 1 : 0 , paramIndices);
6896
+ return false ;
6897
+ };
6898
+ AutoDiffIndexSubset *parameterIndices = nullptr ;
6899
+ AutoDiffIndexSubset *resultIndices = nullptr ;
6900
+ if (parseAutoDiffIndexSubset (" parameters" , parameterIndices))
6901
+ return true ;
6902
+ if (parseAutoDiffIndexSubset (" results" , resultIndices))
6903
+ return true ;
6904
+
6905
+ GenericSignature *derivativeGenSig = nullptr ;
6906
+ // Parse a trailing 'where' clause if any.
6907
+ if (P.Tok .is (tok::kw_where)) {
6908
+ SourceLoc whereLoc;
6909
+ SmallVector<RequirementRepr, 4 > requirementReprs;
6910
+ bool firstTypeInComplete;
6911
+ P.parseGenericWhereClause (whereLoc, requirementReprs, firstTypeInComplete,
6912
+ /* AllowLayoutConstraints*/ false );
6913
+ auto *whereClause = TrailingWhereClause::create (
6914
+ originalFn->getModule ().getASTContext (), whereLoc, requirementReprs);
6915
+ SmallVector<Requirement, 4 > requirements;
6916
+ convertRequirements (P, originalFn, whereClause->getRequirements (),
6917
+ requirements);
6918
+ assert (requirements.size () == requirementReprs.size ());
6919
+ derivativeGenSig = evaluateOrDefault (
6920
+ P.Context .evaluator ,
6921
+ AbstractGenericSignatureRequest{
6922
+ originalFn->getLoweredFunctionType ()->getGenericSignature (),
6923
+ /* addedGenericParams=*/ {},
6924
+ std::move (requirements)},
6925
+ nullptr );
6926
+ }
6927
+
6928
+ SILFunction *jvp = nullptr ;
6929
+ SILFunction *vjp = nullptr ;
6930
+ if (P.Tok .is (tok::l_brace)) {
6931
+ SourceLoc LBraceLoc = P.Tok .getLoc ();
6932
+ P.consumeToken (tok::l_brace);
6933
+
6934
+ if (P.Tok .is (tok::identifier) && P.Tok .getText () == " jvp" ) {
6935
+ P.consumeToken (tok::identifier);
6936
+ if (P.parseToken (tok::colon, diag::sil_diff_witness_expected_keyword,
6937
+ " :" ))
6938
+ return true ;
6939
+ Scope Body (&P, ScopeKind::FunctionBody);
6940
+ if (parseFunctionNameAndType (jvp))
6941
+ return true ;
6942
+ }
6943
+
6944
+ if (P.Tok .is (tok::identifier) && P.Tok .getText () == " vjp" ) {
6945
+ P.consumeToken (tok::identifier);
6946
+ if (P.parseToken (tok::colon, diag::sil_diff_witness_expected_keyword,
6947
+ " :" ))
6948
+ return true ;
6949
+ Scope Body (&P, ScopeKind::FunctionBody);
6950
+ if (parseFunctionNameAndType (vjp))
6951
+ return true ;
6952
+ }
6953
+
6954
+ if (P.parseMatchingToken (tok::r_brace, lastLoc, diag::expected_sil_rbrace,
6955
+ LBraceLoc))
6956
+ return true ;
6957
+ }
6958
+
6959
+ // TODO: Parse `isSerialized` flag.
6960
+ bool isSerialized = false ;
6961
+ SILDifferentiabilityWitness::create (
6962
+ M, *linkage, originalFn, parameterIndices, resultIndices,
6963
+ derivativeGenSig, jvp, vjp, isSerialized);
6761
6964
return false ;
6762
6965
}
6966
+ // SWIFT_ENABLE_TENSORFLOW END
6763
6967
6764
6968
llvm::Optional<llvm::coverage::Counter> SILParser::parseSILCoverageExpr (
6765
6969
llvm::coverage::CounterExpressionBuilder &Builder) {
0 commit comments