Skip to content

Commit 1c2b702

Browse files
committed
---
yaml --- r: 311262 b: refs/heads/tensorflow-merge c: cbad7aa h: refs/heads/master
1 parent 8a85446 commit 1c2b702

File tree

4 files changed

+164
-143
lines changed

4 files changed

+164
-143
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ refs/heads/chase-my-tail: 8bb91443a9e81bbfac92a2621a0af887a1da8dbf
13791379
refs/heads/consider-outer-alternatives: 708bac749ec60a22a79e2eefbe734f9488a7370d
13801380
refs/heads/revert-25740-oops-i-linked-it-again: fdd41aeb682fc488572bdc1cf71b2ff6997ba576
13811381
refs/heads/swift-5.1-branch-06-12-2019: e63b7b2d3b93c48232d386099d0ec525d21d8f8d
1382-
refs/heads/tensorflow-merge: b95ae5c94d596f38d39dcefd7853dab8b58b29a0
1382+
refs/heads/tensorflow-merge: cbad7aac3a6d6edb321919ae39f6353429d7f820
13831383
refs/heads/update-checkout-sha-info: 5832743c5c2a842976c42a508a4c6dcceefb0aef
13841384
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-12-a: 228f0448d9bb909aacbba4afcb7c600a405d15da
13851385
refs/tags/swift-5.1-DEVELOPMENT-SNAPSHOT-2019-06-14-a: 922861a77b5fc2bf46bc917da70ceb15eef76836

branches/tensorflow-merge/include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2502,9 +2502,9 @@ ERROR(differentiable_attr_cannot_diff_wrt_objects_or_existentials,none,
25022502
"class objects and protocol existentials (%0) cannot be differentiated "
25032503
"with respect to", (Type))
25042504
ERROR(differentiable_attr_function_not_same_type_context,none,
2505-
"%0 is not defined in the current declaration context", (DeclName))
2505+
"%0 is not defined in the current type context", (DeclName))
25062506
ERROR(differentiable_attr_specified_not_function,none,
2507-
"%0 is not a function to be used as %select(primal|adjoint)",
2507+
"%0 is not a function to be used as %select{primal|adjoint}1",
25082508
(DeclName, bool))
25092509
ERROR(differentiable_attr_ambiguous_function_identifier,none,
25102510
"ambiguous or overloaded identifier %0 cannot be used in @differentiable "

branches/tensorflow-merge/lib/Sema/TypeCheckAttr.cpp

Lines changed: 150 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,124 +2231,92 @@ void AttributeChecker::visitFrozenAttr(FrozenAttr *attr) {
22312231
}
22322232
}
22332233

2234+
// SWIFT_ENABLE_TENSORFLOW
2235+
// Returns the function declaration corresponding to the given function name and
2236+
// lookup context. If the function declaration cannot be resolved, emits a
2237+
// diagnostic and returns nullptr.
22342238
static FuncDecl *getResolvedFuncDecl(
2235-
DifferentiableAttr::FunctionSpecifier funcSpecifier,
2236-
TypeChecker &TC, FuncDecl *original,
2237-
const std::function<bool(FuncDecl *)> &isValidLookupResult,
2238-
const std::function<void()> &overloadDiagnostic, bool isPrimal) {
2239+
DeclName funcName, SourceLoc funcNameLoc, TypeChecker &TC,
2240+
DeclContext *lookupContext,
2241+
const std::function<bool(FuncDecl *)> &isValidFuncDecl,
2242+
const std::function<bool(FuncDecl *)> &hasValidTypeContext,
2243+
const std::function<void()> &overloadDiagnostic,
2244+
const std::function<void()> &ambiguousDiagnostic,
2245+
const std::function<void()> &notFunctionDiagnostic) {
2246+
22392247
FuncDecl *resolvedFuncDecl = nullptr;
2240-
UnresolvedDeclRefExpr UDRE(funcSpecifier.Name, DeclRefKind::Ordinary,
2241-
funcSpecifier.Loc);
2242-
auto expr = TC.resolveDeclRefExpr(&UDRE, original->getInnermostDeclContext());
2243-
// If it's an unresolved dot expression, it must be a class method or an
2244-
// instance method.
2245-
if (auto dotExpr = dyn_cast<UnresolvedDotExpr>(expr)) {
2246-
// Look up the decl name directly in the current type context.
2247-
auto typeCtx = original->getInnermostTypeContext();
2248-
auto lookupResult = TC.lookupMember(typeCtx,
2249-
typeCtx->getDeclaredInterfaceType(),
2250-
funcSpecifier.Name);
2251-
// Declare error flags.
2252-
bool exprIsNotFunction = false;
2253-
bool overloadNotFound = false;
2254-
2255-
for (auto lookupEntry : lookupResult) {
2256-
auto funcDecl = dyn_cast<FuncDecl>(lookupEntry.getValueDecl());
2257-
// Set flag if the lookup result is not a function declaration.
2258-
if (!funcDecl) {
2259-
exprIsNotFunction = true;
2260-
continue;
2261-
}
2262-
// Set flag if lookup result is invalid.
2263-
if (!isValidLookupResult(funcDecl)) {
2264-
overloadNotFound = true;
2265-
continue;
2266-
}
2267-
// If more than one lookup result has the expected function type, then
2268-
// the function is ambgiuous.
2269-
if (resolvedFuncDecl) {
2270-
TC.diagnose(funcSpecifier.Loc.getBaseNameLoc(),
2271-
diag::differentiable_attr_ambiguous_function_identifier,
2272-
funcSpecifier.Name);
2273-
return nullptr;
2274-
}
2275-
// Resolve the function declaration.
2276-
resolvedFuncDecl = funcDecl;
2277-
}
2278-
// If the function declaration could not be resolved, check error flags.
2279-
if (!resolvedFuncDecl) {
2280-
if (overloadNotFound) {
2281-
overloadDiagnostic();
2282-
return nullptr;
2283-
}
2284-
assert(exprIsNotFunction && "Function declaration could not be resolved");
2285-
TC.diagnose(funcSpecifier.Loc.getBaseNameLoc(),
2286-
diag::differentiable_attr_specified_not_function,
2287-
funcSpecifier.Name, isPrimal);
2288-
return nullptr;
2248+
2249+
// Initialize error flags.
2250+
bool notAFuncDecl = false;
2251+
bool wrongTypeContext = false;
2252+
bool overloadNotFound = false;
2253+
2254+
// Perform lookup, ignoring access control.
2255+
auto options = defaultUnqualifiedLookupOptions |
2256+
NameLookupFlags::IgnoreAccessControl;
2257+
auto results =
2258+
TC.lookupUnqualified(lookupContext, funcName, funcNameLoc, options);
2259+
2260+
// Note: static methods are omitted from `TypeChecker.lookupUnqualified` in
2261+
// Swift 3. The code below is a workaround for resolving them.
2262+
//
2263+
// This is necessary because the stdlib is compiled with `-swift-version 3`
2264+
// for Swift 3 compatibility, and floating point types use the
2265+
// `@differentiable` attribute with static adjoint methods (such as
2266+
// `_adjointAdd`).
2267+
if (lookupContext->getASTContext().isSwiftVersion3() && results.empty() &&
2268+
lookupContext->isTypeContext()) {
2269+
auto tmp = TC.lookupMember(lookupContext,
2270+
lookupContext->getSelfTypeInContext(), funcName);
2271+
for (auto choice : tmp) {
2272+
auto decl = choice.getValueDecl();
2273+
if (!decl) continue;
2274+
auto funcDecl = dyn_cast<FuncDecl>(decl);
2275+
if (!funcDecl) continue;
2276+
results.add(LookupResultEntry(funcDecl));
22892277
}
22902278
}
2291-
// If it's resolved to a type, it's not what we want.
2292-
else if (isa<TypeExpr>(expr))
2293-
TC.diagnose(funcSpecifier.Loc.getBaseNameLoc(),
2294-
diag::differentiable_attr_specified_not_function,
2295-
funcSpecifier.Name, isPrimal);
2296-
// If it's directly resolved to a concrete declaration, it must be a free
2297-
// function in the module context.
2298-
else if (auto declRefExpr = dyn_cast<DeclRefExpr>(expr)) {
2299-
auto funcDecl = dyn_cast<FuncDecl>(declRefExpr->getDecl());
2300-
// If the candidate is not a function, then it's an error.
2301-
if (!funcDecl) {
2302-
TC.diagnose(funcSpecifier.Loc.getBaseNameLoc(),
2303-
diag::differentiable_attr_specified_not_function,
2304-
funcSpecifier.Name, isPrimal);
2305-
return nullptr;
2306-
}
23072279

2308-
// If the original and the primal or adjoint have different parents, or
2309-
// if they both have no type context and are in different modules, then
2310-
// it's an error.
2311-
auto inCompatibleContexts = [&](FuncDecl *decl1, FuncDecl *decl2) {
2312-
if (!decl1->getInnermostTypeContext() &&
2313-
!decl2->getInnermostTypeContext() &&
2314-
decl1->getParentModule() == decl2->getParentModule())
2315-
return true;
2316-
if (decl1->getParent() == decl2->getParent())
2317-
return true;
2318-
return false;
2319-
};
2280+
for (auto choice : results) {
2281+
auto decl = choice.getValueDecl();
2282+
if (!decl) continue;
23202283

2321-
if (!inCompatibleContexts(original, funcDecl)) {
2322-
TC.diagnose(funcSpecifier.Loc.getBaseNameLoc(),
2323-
diag::differentiable_attr_function_not_same_type_context,
2324-
funcSpecifier.Name);
2325-
return nullptr;
2284+
auto funcDecl = dyn_cast<FuncDecl>(decl);
2285+
if (!funcDecl) {
2286+
notAFuncDecl = true;
2287+
continue;
2288+
}
2289+
if (!hasValidTypeContext(funcDecl)) {
2290+
wrongTypeContext = true;
2291+
continue;
2292+
}
2293+
if (!isValidFuncDecl(funcDecl)) {
2294+
overloadNotFound = true;
2295+
continue;
2296+
}
2297+
if (resolvedFuncDecl) {
2298+
ambiguousDiagnostic();
2299+
resolvedFuncDecl = nullptr;
2300+
break;
23262301
}
2327-
// Otherwise, the original and the function are declared in the same
2328-
// context. Save this candidate for further type checking.
23292302
resolvedFuncDecl = funcDecl;
23302303
}
2331-
// Overloaded names are not supported.
2332-
// FIXME: Resolve using the expected function type.
2333-
else if (isa<OverloadedDeclRefExpr>(expr)) {
2334-
TC.diagnose(funcSpecifier.Loc.getBaseNameLoc(),
2335-
diag::differentiable_attr_ambiguous_function_identifier,
2336-
funcSpecifier.Name);
2337-
return nullptr;
2338-
}
2339-
// Error expressions have been handled already.
2340-
else if (isa<ErrorExpr>(expr))
2341-
return nullptr; // Diagnostics already emitted.
2342-
else
2343-
llvm_unreachable("Unhandled expr kind");
2344-
2345-
assert(resolvedFuncDecl && "Function declaration should have been resolved");
2346-
2347-
// Perform an additional check to handle cases that were not
2348-
// UnresolvedDotExpr.
2349-
if (!isValidLookupResult(resolvedFuncDecl)) {
2350-
overloadDiagnostic();
2351-
return nullptr;
2304+
// If function declaration could not be resolved, emit the appropriate
2305+
// diagnostic.
2306+
if (!resolvedFuncDecl) {
2307+
if (results.empty()) {
2308+
TC.diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName,
2309+
funcName.isOperator());
2310+
} else if (wrongTypeContext) {
2311+
TC.diagnose(funcNameLoc,
2312+
diag::differentiable_attr_function_not_same_type_context,
2313+
funcName);
2314+
} else if (overloadNotFound) {
2315+
overloadDiagnostic();
2316+
} else {
2317+
assert(notAFuncDecl && "Expected 'not a function' error");
2318+
notFunctionDiagnostic();
2319+
}
23522320
}
23532321

23542322
return resolvedFuncDecl;
@@ -2381,10 +2349,48 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23812349
auto originalParamsTy =
23822350
originalParams.getInterfaceType(original->getASTContext());
23832351

2352+
// If the original function and the primal/adjoint have different parents, or
2353+
// if they both have no type context and are in different modules, then
2354+
// it's an error.
2355+
auto hasValidTypeContext = [&](FuncDecl *decl) {
2356+
if (!original->getInnermostTypeContext() &&
2357+
!decl->getInnermostTypeContext() &&
2358+
original->getParentModule() == decl->getParentModule())
2359+
return true;
2360+
if (auto typeCtx1 = original->getInnermostTypeContext()) {
2361+
if (auto typeCtx2 = decl->getInnermostTypeContext()) {
2362+
auto type1 = typeCtx1->getDeclaredInterfaceType();
2363+
auto type2 = typeCtx2->getDeclaredInterfaceType();
2364+
return type1->isEqual(type2);
2365+
}
2366+
}
2367+
return original->getParent() == decl->getParent();
2368+
};
2369+
23842370
// Resolve the primal declaration, if it exists.
23852371
FuncDecl *resolvedPrimal = nullptr;
23862372
if (attr->getPrimal()) {
23872373
auto primalSpecifier = attr->getPrimal().getValue();
2374+
auto primalNameLoc = primalSpecifier.Loc.getBaseNameLoc();
2375+
2376+
auto primalTypeCtx = original->getInnermostTypeContext();
2377+
if (!primalTypeCtx) primalTypeCtx = original->getParent();
2378+
2379+
auto primalOverloadDiagnostic = [&]() {
2380+
TC.diagnose(primalNameLoc,
2381+
diag::differentiable_attr_primal_overload_not_found,
2382+
primalSpecifier.Name, originalParamsTy);
2383+
};
2384+
auto primalAmbiguousDiagnostic = [&]() {
2385+
TC.diagnose(primalNameLoc,
2386+
diag::differentiable_attr_ambiguous_function_identifier,
2387+
primalSpecifier.Name);
2388+
};
2389+
auto primalNotFunctionDiagnostic = [&]() {
2390+
TC.diagnose(primalNameLoc,
2391+
diag::differentiable_attr_specified_not_function,
2392+
primalSpecifier.Name, /*isPrimal*/ true);
2393+
};
23882394

23892395
auto isValidPrimal = [&](FuncDecl *primalCandidate) {
23902396
// Returns true if the primal candidate
@@ -2417,16 +2423,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
24172423
return true;
24182424
};
24192425

2420-
auto primalOverloadDiagnostic = [&] {
2421-
TC.diagnose(primalSpecifier.Loc.getBaseNameLoc(),
2422-
diag::differentiable_attr_primal_overload_not_found,
2423-
primalSpecifier.Name, originalParamsTy);
2424-
};
2426+
resolvedPrimal =
2427+
getResolvedFuncDecl(primalSpecifier.Name, primalNameLoc,
2428+
TC, primalTypeCtx, isValidPrimal, hasValidTypeContext,
2429+
primalOverloadDiagnostic, primalAmbiguousDiagnostic,
2430+
primalNotFunctionDiagnostic);
24252431

2426-
resolvedPrimal = getResolvedFuncDecl(primalSpecifier, TC, original,
2427-
isValidPrimal,
2428-
primalOverloadDiagnostic,
2429-
/*isPrimal*/ true);
24302432
if (!resolvedPrimal) return;
24312433
// Memorize the primal reference in the attribute.
24322434
attr->setPrimalFunction(resolvedPrimal);
@@ -2567,26 +2569,45 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25672569
}
25682570

25692571
// Resolve the adjoint declaration.
2572+
FuncDecl *resolvedAdjoint = nullptr;
2573+
auto adjointSpecifier = attr->getAdjoint();
2574+
auto adjointNameLoc = adjointSpecifier.Loc.getBaseNameLoc();
2575+
2576+
auto adjointTypeCtx = original->getInnermostTypeContext();
2577+
if (!adjointTypeCtx) adjointTypeCtx = original->getParent();
2578+
2579+
auto adjointOverloadDiagnostic = [&]() {
2580+
TC.diagnose(adjointNameLoc,
2581+
diag::differentiable_attr_adjoint_overload_not_found,
2582+
adjointSpecifier.Name, expectedAdjointFnTy);
2583+
};
2584+
auto adjointAmbiguousDiagnostic = [&]() {
2585+
TC.diagnose(adjointNameLoc,
2586+
diag::differentiable_attr_ambiguous_function_identifier,
2587+
adjointSpecifier.Name);
2588+
};
2589+
auto adjointNotFunctionDiagnostic = [&]() {
2590+
TC.diagnose(adjointNameLoc,
2591+
diag::differentiable_attr_specified_not_function,
2592+
adjointSpecifier.Name, /*isPrimal*/ false);
2593+
};
2594+
25702595
auto isValidAdjoint = [&](FuncDecl *adjointCandidate) {
25712596
// Returns true if adjoint candidate has the expected type.
25722597
auto adjointType = adjointCandidate->getInterfaceType()
25732598
->getUnlabeledType(original->getASTContext());
25742599
return adjointType->isEqual(expectedAdjointFnTy);
25752600
};
25762601

2577-
auto adjointOverloadDiagnostic = [&] {
2578-
TC.diagnose(attr->getAdjoint().Loc.getBaseNameLoc(),
2579-
diag::differentiable_attr_adjoint_overload_not_found,
2580-
attr->getAdjoint().Name, expectedAdjointFnTy);
2581-
};
2602+
resolvedAdjoint =
2603+
getResolvedFuncDecl(adjointSpecifier.Name, adjointNameLoc,
2604+
TC, adjointTypeCtx, isValidAdjoint, hasValidTypeContext,
2605+
adjointOverloadDiagnostic, adjointAmbiguousDiagnostic,
2606+
adjointNotFunctionDiagnostic);
25822607

2583-
FuncDecl *resolvedAdjoint = getResolvedFuncDecl(attr->getAdjoint(), TC,
2584-
original, isValidAdjoint,
2585-
adjointOverloadDiagnostic,
2586-
/*isPrimal*/ false);
25872608
if (!resolvedAdjoint) return;
2588-
// Done checking @differentiable attribute. Memorize the adjoint reference in
2589-
// the attribute.
2609+
// Done checking @differentiable attribute.
2610+
// Memorize the adjoint reference in the attribute.
25902611
attr->setAdjointFunction(resolvedAdjoint);
25912612
}
25922613

0 commit comments

Comments
 (0)