@@ -2231,124 +2231,92 @@ void AttributeChecker::visitFrozenAttr(FrozenAttr *attr) {
2231
2231
}
2232
2232
}
2233
2233
2234
+ // SWIFT_ENABLE_TENSORFLOW
2235
+ // Returns the function declaration corresponding to the given function name and
2236
+ // lookup context. If the function declaration cannot be resolved, emits a
2237
+ // diagnostic and returns nullptr.
2234
2238
static FuncDecl *getResolvedFuncDecl (
2235
- DifferentiableAttr::FunctionSpecifier funcSpecifier,
2236
- TypeChecker &TC, FuncDecl *original,
2237
- const std::function<bool (FuncDecl *)> &isValidLookupResult,
2238
- const std::function<void()> &overloadDiagnostic, bool isPrimal) {
2239
+ DeclName funcName, SourceLoc funcNameLoc, TypeChecker &TC,
2240
+ DeclContext *lookupContext,
2241
+ const std::function<bool (FuncDecl *)> &isValidFuncDecl,
2242
+ const std::function<bool(FuncDecl *)> &hasValidTypeContext,
2243
+ const std::function<void()> &overloadDiagnostic,
2244
+ const std::function<void()> &ambiguousDiagnostic,
2245
+ const std::function<void()> ¬FunctionDiagnostic) {
2246
+
2239
2247
FuncDecl *resolvedFuncDecl = nullptr ;
2240
- UnresolvedDeclRefExpr UDRE (funcSpecifier.Name , DeclRefKind::Ordinary,
2241
- funcSpecifier.Loc );
2242
- auto expr = TC.resolveDeclRefExpr (&UDRE, original->getInnermostDeclContext ());
2243
- // If it's an unresolved dot expression, it must be a class method or an
2244
- // instance method.
2245
- if (auto dotExpr = dyn_cast<UnresolvedDotExpr>(expr)) {
2246
- // Look up the decl name directly in the current type context.
2247
- auto typeCtx = original->getInnermostTypeContext ();
2248
- auto lookupResult = TC.lookupMember (typeCtx,
2249
- typeCtx->getDeclaredInterfaceType (),
2250
- funcSpecifier.Name );
2251
- // Declare error flags.
2252
- bool exprIsNotFunction = false ;
2253
- bool overloadNotFound = false ;
2254
-
2255
- for (auto lookupEntry : lookupResult) {
2256
- auto funcDecl = dyn_cast<FuncDecl>(lookupEntry.getValueDecl ());
2257
- // Set flag if the lookup result is not a function declaration.
2258
- if (!funcDecl) {
2259
- exprIsNotFunction = true ;
2260
- continue ;
2261
- }
2262
- // Set flag if lookup result is invalid.
2263
- if (!isValidLookupResult (funcDecl)) {
2264
- overloadNotFound = true ;
2265
- continue ;
2266
- }
2267
- // If more than one lookup result has the expected function type, then
2268
- // the function is ambgiuous.
2269
- if (resolvedFuncDecl) {
2270
- TC.diagnose (funcSpecifier.Loc .getBaseNameLoc (),
2271
- diag::differentiable_attr_ambiguous_function_identifier,
2272
- funcSpecifier.Name );
2273
- return nullptr ;
2274
- }
2275
- // Resolve the function declaration.
2276
- resolvedFuncDecl = funcDecl;
2277
- }
2278
- // If the function declaration could not be resolved, check error flags.
2279
- if (!resolvedFuncDecl) {
2280
- if (overloadNotFound) {
2281
- overloadDiagnostic ();
2282
- return nullptr ;
2283
- }
2284
- assert (exprIsNotFunction && " Function declaration could not be resolved" );
2285
- TC.diagnose (funcSpecifier.Loc .getBaseNameLoc (),
2286
- diag::differentiable_attr_specified_not_function,
2287
- funcSpecifier.Name , isPrimal);
2288
- return nullptr ;
2248
+
2249
+ // Initialize error flags.
2250
+ bool notAFuncDecl = false ;
2251
+ bool wrongTypeContext = false ;
2252
+ bool overloadNotFound = false ;
2253
+
2254
+ // Perform lookup, ignoring access control.
2255
+ auto options = defaultUnqualifiedLookupOptions |
2256
+ NameLookupFlags::IgnoreAccessControl;
2257
+ auto results =
2258
+ TC.lookupUnqualified (lookupContext, funcName, funcNameLoc, options);
2259
+
2260
+ // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in
2261
+ // Swift 3. The code below is a workaround for resolving them.
2262
+ //
2263
+ // This is necessary because the stdlib is compiled with `-swift-version 3`
2264
+ // for Swift 3 compatibility, and floating point types use the
2265
+ // `@differentiable` attribute with static adjoint methods (such as
2266
+ // `_adjointAdd`).
2267
+ if (lookupContext->getASTContext ().isSwiftVersion3 () && results.empty () &&
2268
+ lookupContext->isTypeContext ()) {
2269
+ auto tmp = TC.lookupMember (lookupContext,
2270
+ lookupContext->getSelfTypeInContext (), funcName);
2271
+ for (auto choice : tmp) {
2272
+ auto decl = choice.getValueDecl ();
2273
+ if (!decl) continue ;
2274
+ auto funcDecl = dyn_cast<FuncDecl>(decl);
2275
+ if (!funcDecl) continue ;
2276
+ results.add (LookupResultEntry (funcDecl));
2289
2277
}
2290
2278
}
2291
- // If it's resolved to a type, it's not what we want.
2292
- else if (isa<TypeExpr>(expr))
2293
- TC.diagnose (funcSpecifier.Loc .getBaseNameLoc (),
2294
- diag::differentiable_attr_specified_not_function,
2295
- funcSpecifier.Name , isPrimal);
2296
- // If it's directly resolved to a concrete declaration, it must be a free
2297
- // function in the module context.
2298
- else if (auto declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
2299
- auto funcDecl = dyn_cast<FuncDecl>(declRefExpr->getDecl ());
2300
- // If the candidate is not a function, then it's an error.
2301
- if (!funcDecl) {
2302
- TC.diagnose (funcSpecifier.Loc .getBaseNameLoc (),
2303
- diag::differentiable_attr_specified_not_function,
2304
- funcSpecifier.Name , isPrimal);
2305
- return nullptr ;
2306
- }
2307
2279
2308
- // If the original and the primal or adjoint have different parents, or
2309
- // if they both have no type context and are in different modules, then
2310
- // it's an error.
2311
- auto inCompatibleContexts = [&](FuncDecl *decl1, FuncDecl *decl2) {
2312
- if (!decl1->getInnermostTypeContext () &&
2313
- !decl2->getInnermostTypeContext () &&
2314
- decl1->getParentModule () == decl2->getParentModule ())
2315
- return true ;
2316
- if (decl1->getParent () == decl2->getParent ())
2317
- return true ;
2318
- return false ;
2319
- };
2280
+ for (auto choice : results) {
2281
+ auto decl = choice.getValueDecl ();
2282
+ if (!decl) continue ;
2320
2283
2321
- if (!inCompatibleContexts (original, funcDecl)) {
2322
- TC.diagnose (funcSpecifier.Loc .getBaseNameLoc (),
2323
- diag::differentiable_attr_function_not_same_type_context,
2324
- funcSpecifier.Name );
2325
- return nullptr ;
2284
+ auto funcDecl = dyn_cast<FuncDecl>(decl);
2285
+ if (!funcDecl) {
2286
+ notAFuncDecl = true ;
2287
+ continue ;
2288
+ }
2289
+ if (!hasValidTypeContext (funcDecl)) {
2290
+ wrongTypeContext = true ;
2291
+ continue ;
2292
+ }
2293
+ if (!isValidFuncDecl (funcDecl)) {
2294
+ overloadNotFound = true ;
2295
+ continue ;
2296
+ }
2297
+ if (resolvedFuncDecl) {
2298
+ ambiguousDiagnostic ();
2299
+ resolvedFuncDecl = nullptr ;
2300
+ break ;
2326
2301
}
2327
- // Otherwise, the original and the function are declared in the same
2328
- // context. Save this candidate for further type checking.
2329
2302
resolvedFuncDecl = funcDecl;
2330
2303
}
2331
- // Overloaded names are not supported.
2332
- // FIXME: Resolve using the expected function type.
2333
- else if (isa<OverloadedDeclRefExpr>(expr)) {
2334
- TC.diagnose (funcSpecifier.Loc .getBaseNameLoc (),
2335
- diag::differentiable_attr_ambiguous_function_identifier,
2336
- funcSpecifier.Name );
2337
- return nullptr ;
2338
- }
2339
- // Error expressions have been handled already.
2340
- else if (isa<ErrorExpr>(expr))
2341
- return nullptr ; // Diagnostics already emitted.
2342
- else
2343
- llvm_unreachable (" Unhandled expr kind" );
2344
-
2345
- assert (resolvedFuncDecl && " Function declaration should have been resolved" );
2346
-
2347
- // Perform an additional check to handle cases that were not
2348
- // UnresolvedDotExpr.
2349
- if (!isValidLookupResult (resolvedFuncDecl)) {
2350
- overloadDiagnostic ();
2351
- return nullptr ;
2304
+ // If function declaration could not be resolved, emit the appropriate
2305
+ // diagnostic.
2306
+ if (!resolvedFuncDecl) {
2307
+ if (results.empty ()) {
2308
+ TC.diagnose (funcNameLoc, diag::use_unresolved_identifier, funcName,
2309
+ funcName.isOperator ());
2310
+ } else if (wrongTypeContext) {
2311
+ TC.diagnose (funcNameLoc,
2312
+ diag::differentiable_attr_function_not_same_type_context,
2313
+ funcName);
2314
+ } else if (overloadNotFound) {
2315
+ overloadDiagnostic ();
2316
+ } else {
2317
+ assert (notAFuncDecl && " Expected 'not a function' error" );
2318
+ notFunctionDiagnostic ();
2319
+ }
2352
2320
}
2353
2321
2354
2322
return resolvedFuncDecl;
@@ -2381,10 +2349,48 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
2381
2349
auto originalParamsTy =
2382
2350
originalParams.getInterfaceType (original->getASTContext ());
2383
2351
2352
+ // If the original function and the primal/adjoint have different parents, or
2353
+ // if they both have no type context and are in different modules, then
2354
+ // it's an error.
2355
+ auto hasValidTypeContext = [&](FuncDecl *decl) {
2356
+ if (!original->getInnermostTypeContext () &&
2357
+ !decl->getInnermostTypeContext () &&
2358
+ original->getParentModule () == decl->getParentModule ())
2359
+ return true ;
2360
+ if (auto typeCtx1 = original->getInnermostTypeContext ()) {
2361
+ if (auto typeCtx2 = decl->getInnermostTypeContext ()) {
2362
+ auto type1 = typeCtx1->getDeclaredInterfaceType ();
2363
+ auto type2 = typeCtx2->getDeclaredInterfaceType ();
2364
+ return type1->isEqual (type2);
2365
+ }
2366
+ }
2367
+ return original->getParent () == decl->getParent ();
2368
+ };
2369
+
2384
2370
// Resolve the primal declaration, if it exists.
2385
2371
FuncDecl *resolvedPrimal = nullptr ;
2386
2372
if (attr->getPrimal ()) {
2387
2373
auto primalSpecifier = attr->getPrimal ().getValue ();
2374
+ auto primalNameLoc = primalSpecifier.Loc .getBaseNameLoc ();
2375
+
2376
+ auto primalTypeCtx = original->getInnermostTypeContext ();
2377
+ if (!primalTypeCtx) primalTypeCtx = original->getParent ();
2378
+
2379
+ auto primalOverloadDiagnostic = [&]() {
2380
+ TC.diagnose (primalNameLoc,
2381
+ diag::differentiable_attr_primal_overload_not_found,
2382
+ primalSpecifier.Name , originalParamsTy);
2383
+ };
2384
+ auto primalAmbiguousDiagnostic = [&]() {
2385
+ TC.diagnose (primalNameLoc,
2386
+ diag::differentiable_attr_ambiguous_function_identifier,
2387
+ primalSpecifier.Name );
2388
+ };
2389
+ auto primalNotFunctionDiagnostic = [&]() {
2390
+ TC.diagnose (primalNameLoc,
2391
+ diag::differentiable_attr_specified_not_function,
2392
+ primalSpecifier.Name , /* isPrimal*/ true );
2393
+ };
2388
2394
2389
2395
auto isValidPrimal = [&](FuncDecl *primalCandidate) {
2390
2396
// Returns true if the primal candidate
@@ -2417,16 +2423,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
2417
2423
return true ;
2418
2424
};
2419
2425
2420
- auto primalOverloadDiagnostic = [&] {
2421
- TC. diagnose (primalSpecifier.Loc . getBaseNameLoc () ,
2422
- diag::differentiable_attr_primal_overload_not_found ,
2423
- primalSpecifier. Name , originalParamsTy);
2424
- } ;
2426
+ resolvedPrimal =
2427
+ getResolvedFuncDecl (primalSpecifier.Name , primalNameLoc ,
2428
+ TC, primalTypeCtx, isValidPrimal, hasValidTypeContext ,
2429
+ primalOverloadDiagnostic, primalAmbiguousDiagnostic,
2430
+ primalNotFunctionDiagnostic) ;
2425
2431
2426
- resolvedPrimal = getResolvedFuncDecl (primalSpecifier, TC, original,
2427
- isValidPrimal,
2428
- primalOverloadDiagnostic,
2429
- /* isPrimal*/ true );
2430
2432
if (!resolvedPrimal) return ;
2431
2433
// Memorize the primal reference in the attribute.
2432
2434
attr->setPrimalFunction (resolvedPrimal);
@@ -2567,26 +2569,45 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
2567
2569
}
2568
2570
2569
2571
// Resolve the adjoint declaration.
2572
+ FuncDecl *resolvedAdjoint = nullptr ;
2573
+ auto adjointSpecifier = attr->getAdjoint ();
2574
+ auto adjointNameLoc = adjointSpecifier.Loc .getBaseNameLoc ();
2575
+
2576
+ auto adjointTypeCtx = original->getInnermostTypeContext ();
2577
+ if (!adjointTypeCtx) adjointTypeCtx = original->getParent ();
2578
+
2579
+ auto adjointOverloadDiagnostic = [&]() {
2580
+ TC.diagnose (adjointNameLoc,
2581
+ diag::differentiable_attr_adjoint_overload_not_found,
2582
+ adjointSpecifier.Name , expectedAdjointFnTy);
2583
+ };
2584
+ auto adjointAmbiguousDiagnostic = [&]() {
2585
+ TC.diagnose (adjointNameLoc,
2586
+ diag::differentiable_attr_ambiguous_function_identifier,
2587
+ adjointSpecifier.Name );
2588
+ };
2589
+ auto adjointNotFunctionDiagnostic = [&]() {
2590
+ TC.diagnose (adjointNameLoc,
2591
+ diag::differentiable_attr_specified_not_function,
2592
+ adjointSpecifier.Name , /* isPrimal*/ false );
2593
+ };
2594
+
2570
2595
auto isValidAdjoint = [&](FuncDecl *adjointCandidate) {
2571
2596
// Returns true if adjoint candidate has the expected type.
2572
2597
auto adjointType = adjointCandidate->getInterfaceType ()
2573
2598
->getUnlabeledType (original->getASTContext ());
2574
2599
return adjointType->isEqual (expectedAdjointFnTy);
2575
2600
};
2576
2601
2577
- auto adjointOverloadDiagnostic = [&] {
2578
- TC. diagnose (attr-> getAdjoint (). Loc . getBaseNameLoc () ,
2579
- diag::differentiable_attr_adjoint_overload_not_found ,
2580
- attr-> getAdjoint (). Name , expectedAdjointFnTy);
2581
- } ;
2602
+ resolvedAdjoint =
2603
+ getResolvedFuncDecl (adjointSpecifier. Name , adjointNameLoc ,
2604
+ TC, adjointTypeCtx, isValidAdjoint, hasValidTypeContext ,
2605
+ adjointOverloadDiagnostic, adjointAmbiguousDiagnostic,
2606
+ adjointNotFunctionDiagnostic) ;
2582
2607
2583
- FuncDecl *resolvedAdjoint = getResolvedFuncDecl (attr->getAdjoint (), TC,
2584
- original, isValidAdjoint,
2585
- adjointOverloadDiagnostic,
2586
- /* isPrimal*/ false );
2587
2608
if (!resolvedAdjoint) return ;
2588
- // Done checking @differentiable attribute. Memorize the adjoint reference in
2589
- // the attribute.
2609
+ // Done checking @differentiable attribute.
2610
+ // Memorize the adjoint reference in the attribute.
2590
2611
attr->setAdjointFunction (resolvedAdjoint);
2591
2612
}
2592
2613
0 commit comments