Skip to content

Ncgenerics test fixes kavon v7 #71515

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 42 additions & 89 deletions include/swift/AST/Evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,47 +70,6 @@ class PrettyStackTraceRequest : public llvm::PrettyStackTraceEntry {
}
};

/// An llvm::ErrorInfo container for a request in which a cycle was detected
/// and diagnosed.
template <typename Request>
struct CyclicalRequestError :
public llvm::ErrorInfo<CyclicalRequestError<Request>> {
public:
static char ID;
const Request &request;
const Evaluator &evaluator;

CyclicalRequestError(const Request &request, const Evaluator &evaluator)
: request(request), evaluator(evaluator) {}

virtual void log(llvm::raw_ostream &out) const override;

virtual std::error_code convertToErrorCode() const override {
// This is essentially unused, but is a temporary requirement for
// llvm::ErrorInfo subclasses.
llvm_unreachable("shouldn't get std::error_code from CyclicalRequestError");
}
};

template <typename Request>
char CyclicalRequestError<Request>::ID = '\0';

/// Evaluates a given request or returns a default value if a cycle is detected.
template <typename Request>
typename Request::OutputType
evaluateOrDefault(
Evaluator &eval, Request req, typename Request::OutputType def) {
auto result = eval(req);
if (auto err = result.takeError()) {
llvm::handleAllErrors(std::move(err),
[](const CyclicalRequestError<Request> &E) {
// cycle detected
});
return def;
}
return *result;
}

/// Report that a request of the given kind is being evaluated, so it
/// can be recorded by the stats reporter.
template<typename Request>
Expand Down Expand Up @@ -256,37 +215,26 @@ class Evaluator {

/// Retrieve the result produced by evaluating a request that can
/// be cached.
template<typename Request,
template<typename Request, typename Fn,
typename std::enable_if<Request::isEverCached>::type * = nullptr>
llvm::Expected<typename Request::OutputType>
operator()(const Request &request) {
typename Request::OutputType
operator()(const Request &request, Fn defaultValueFn) {
// The request can be cached, but check a predicate to determine
// whether this particular instance is cached. This allows more
// fine-grained control over which instances get cache.
if (request.isCached())
return getResultCached(request);
return getResultCached(request, std::move(defaultValueFn));

return getResultUncached(request);
return getResultUncached(request, std::move(defaultValueFn));
}

/// Retrieve the result produced by evaluating a request that
/// will never be cached.
template<typename Request,
template<typename Request, typename Fn,
typename std::enable_if<!Request::isEverCached>::type * = nullptr>
llvm::Expected<typename Request::OutputType>
operator()(const Request &request) {
return getResultUncached(request);
}

/// Evaluate a set of requests and return their results as a tuple.
///
/// Use this to describe cases where there are multiple (known)
/// requests that all need to be satisfied.
template<typename ...Requests>
std::tuple<llvm::Expected<typename Requests::OutputType>...>
operator()(const Requests &...requests) {
return std::tuple<llvm::Expected<typename Requests::OutputType>...>(
(*this)(requests)...);
typename Request::OutputType
operator()(const Request &request, Fn defaultValueFn) {
return getResultUncached(request, std::move(defaultValueFn));
}

/// Cache a precomputed value for the given request, so that it will not
Expand All @@ -304,7 +252,9 @@ class Evaluator {
typename std::enable_if<!Request::hasExternalCache>::type* = nullptr>
void cacheOutput(const Request &request,
typename Request::OutputType &&output) {
cache.insert<Request>(request, std::move(output));
bool inserted = cache.insert<Request>(request, std::move(output));
assert(inserted && "Request result was already cached");
(void) inserted;
}

template<typename Request,
Expand Down Expand Up @@ -351,15 +301,14 @@ class Evaluator {
void finishedRequest(const ActiveRequest &request);

/// Produce the result of the request without caching.
template<typename Request>
llvm::Expected<typename Request::OutputType>
getResultUncached(const Request &request) {
template<typename Request, typename Fn>
typename Request::OutputType
getResultUncached(const Request &request, Fn defaultValueFn) {
auto activeReq = ActiveRequest(request);

// Check for a cycle.
if (checkDependency(activeReq)) {
return llvm::Error(
std::make_unique<CyclicalRequestError<Request>>(request, *this));
return defaultValueFn();
}

PrettyStackTraceRequest<Request> prettyStackTrace(request);
Expand All @@ -370,7 +319,7 @@ class Evaluator {

recorder.beginRequest<Request>();

auto &&result = getRequestFunction<Request>()(request, *this);
auto result = getRequestFunction<Request>()(request, *this);

recorder.endRequest<Request>(request);

Expand All @@ -381,16 +330,16 @@ class Evaluator {
// done.
finishedRequest(activeReq);

return std::move(result);
return result;
}

/// Get the result of a request, consulting an external cache
/// provided by the request to retrieve previously-computed results
/// and detect recursion.
template<typename Request,
template<typename Request, typename Fn,
typename std::enable_if<Request::hasExternalCache>::type * = nullptr>
llvm::Expected<typename Request::OutputType>
getResultCached(const Request &request) {
typename Request::OutputType
getResultCached(const Request &request, Fn defaultValueFn) {
// If there is a cached result, return it.
if (auto cached = request.getCachedResult()) {
recorder.replayCachedRequest(request);
Expand All @@ -399,13 +348,10 @@ class Evaluator {
}

// Compute the result.
auto result = getResultUncached(request);
auto result = getResultUncached(request, std::move(defaultValueFn));

// Cache the result if applicable.
if (!result)
return result;

request.cacheResult(*result);
request.cacheResult(result);

// Return it.
return result;
Expand All @@ -414,10 +360,10 @@ class Evaluator {
/// Get the result of a request, consulting the general cache to
/// retrieve previously-computed results and detect recursion.
template<
typename Request,
typename Request, typename Fn,
typename std::enable_if<!Request::hasExternalCache>::type * = nullptr>
llvm::Expected<typename Request::OutputType>
getResultCached(const Request &request) {
typename Request::OutputType
getResultCached(const Request &request, Fn defaultValueFn) {
// If we already have an entry for this request in the cache, return it.
auto known = cache.find_as<Request>(request);
if (known != cache.end<Request>()) {
Expand All @@ -428,12 +374,10 @@ class Evaluator {
}

// Compute the result.
auto result = getResultUncached(request);
if (!result)
return result;
auto result = getResultUncached(request, std::move(defaultValueFn));

// Cache the result.
cache.insert<Request>(request, *result);
cache.insert<Request>(request, result);
return result;
}

Expand Down Expand Up @@ -465,11 +409,20 @@ class Evaluator {
}
};

template <typename Request>
void CyclicalRequestError<Request>::log(llvm::raw_ostream &out) const {
out << "Cycle detected:\n";
simple_display(out, request);
out << "\n";
/// Evaluates a given request or returns a default value if a cycle is detected.
template<typename Request>
typename Request::OutputType
evaluateOrDefault(Evaluator &eval, Request req, typename Request::OutputType def) {
return eval(req, [def]() { return def; });
}

/// Evaluates a given request or returns a default value if a cycle is detected.
template<typename Request>
typename Request::OutputType
evaluateOrFatal(Evaluator &eval, Request req) {
return eval(req, []() -> typename Request::OutputType {
llvm::report_fatal_error("Request cycle");
});
}

} // end namespace evaluator
Expand Down
7 changes: 3 additions & 4 deletions include/swift/AST/RequestCache.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,11 @@ class RequestCache {
}

template <typename Request>
void insert(Request req, typename Request::OutputType val) {
bool insert(Request req, typename Request::OutputType val) {
auto *cache = getCache<Request>();
auto result = cache->insert({RequestKey<Request>(std::move(req)),
std::move(val)});
assert(result.second && "Request result was already cached");
(void) result;
std::move(val)});
return result.second;
}

template <typename Request>
Expand Down
13 changes: 5 additions & 8 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -2008,9 +2008,9 @@ class InferredGenericSignatureRequest :
GenericParamList *,
WhereClauseOwner,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>,
bool, bool),
RequestFlags::Cached> {
SmallVector<TypeBase *, 2>,
SourceLoc, bool, bool),
RequestFlags::Uncached> {
public:
using SimpleRequest::SimpleRequest;

Expand All @@ -2024,13 +2024,10 @@ class InferredGenericSignatureRequest :
GenericParamList *genericParams,
WhereClauseOwner whereClause,
SmallVector<Requirement, 2> addedRequirements,
SmallVector<TypeLoc, 2> inferenceSources,
bool isExtension, bool allowInverses) const;
SmallVector<TypeBase *, 2> inferenceSources,
SourceLoc loc, bool isExtension, bool allowInverses) const;

public:
// Separate caching.
bool isCached() const { return true; }

/// Inferred generic signature requests don't have source-location info.
SourceLoc getNearestLoc() const {
return SourceLoc();
Expand Down
5 changes: 3 additions & 2 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,9 @@ SWIFT_REQUEST(TypeChecker, InferredGenericSignatureRequest,
GenericParamList *,
WhereClauseOwner,
SmallVector<Requirement, 2>,
SmallVector<TypeLoc, 2>, bool, bool),
Cached, NoLocationInfo)
SmallVector<TypeBase *, 2>,
SourceLoc, bool, bool),
Uncached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, DistributedModuleIsAvailableRequest,
bool(ModuleDecl *), Cached, NoLocationInfo)
SWIFT_REQUEST(TypeChecker, InheritedTypeRequest,
Expand Down
32 changes: 13 additions & 19 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1330,16 +1330,9 @@ static GenericSignature getPlaceholderGenericSignature(
GenericSignature GenericContext::getGenericSignature() const {
// Don't use evaluateOrDefault() here, because getting the 'default value'
// is slightly expensive here so we don't want to do it eagerly.
auto result = getASTContext().evaluator(
GenericSignatureRequest{const_cast<GenericContext *>(this)});
if (auto err = result.takeError()) {
llvm::handleAllErrors(std::move(err),
[](const CyclicalRequestError<GenericSignatureRequest> &E) {
// cycle detected
});
return getPlaceholderGenericSignature(this);
}
return *result;
return getASTContext().evaluator(
GenericSignatureRequest{const_cast<GenericContext *>(this)},
[this]() { return getPlaceholderGenericSignature(this); });
}

GenericEnvironment *GenericContext::getGenericEnvironment() const {
Expand Down Expand Up @@ -3866,12 +3859,8 @@ bool ValueDecl::isRecursiveValidation() const {

Type ValueDecl::getInterfaceType() const {
auto &ctx = getASTContext();
if (auto type =
evaluateOrDefault(ctx.evaluator,
InterfaceTypeRequest{const_cast<ValueDecl *>(this)},
Type()))
return type;
return ErrorType::get(ctx);
return ctx.evaluator(InterfaceTypeRequest{const_cast<ValueDecl *>(this)},
[&ctx]() { return ErrorType::get(ctx); });
}

void ValueDecl::setInterfaceType(Type type) {
Expand Down Expand Up @@ -6741,7 +6730,6 @@ bool ProtocolDecl::isComputingRequirementSignature() const {
}

void ProtocolDecl::setRequirementSignature(RequirementSignature requirementSig) {
assert(!RequirementSig && "requirement signature already set");
RequirementSig = requirementSig;
}

Expand Down Expand Up @@ -10919,8 +10907,14 @@ Type ClassDecl::getSuperclass() const {

ClassDecl *ClassDecl::getSuperclassDecl() const {
ASTContext &ctx = getASTContext();
return evaluateOrDefault(ctx.evaluator,
SuperclassDeclRequest{const_cast<ClassDecl *>(this)}, nullptr);
auto result = evaluateOrDefault(ctx.evaluator,
SuperclassDeclRequest{const_cast<ClassDecl *>(this)},
const_cast<ClassDecl *>(this));

if (result == this)
return nullptr;

return result;
}

void ClassDecl::setSuperclass(Type superclass) {
Expand Down
14 changes: 4 additions & 10 deletions lib/AST/NameLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3206,16 +3206,10 @@ SuperclassDeclRequest::evaluate(Evaluator &evaluator,
// inheritance hierarchy by evaluating its superclass. This forces the
// diagnostic at this point and then suppresses the superclass failure.
if (superclass) {
auto result = Ctx.evaluator(SuperclassDeclRequest{superclass});
bool hadCycle = false;
if (auto err = result.takeError()) {
llvm::handleAllErrors(std::move(err),
[&hadCycle](const CyclicalRequestError<SuperclassDeclRequest> &E) {
hadCycle = true;
});

if (hadCycle)
return nullptr;
if (evaluateOrDefault(Ctx.evaluator,
SuperclassDeclRequest{superclass},
superclass) == superclass) {
return nullptr;
}

return superclass;
Expand Down
Loading