Skip to content

Commit 1262193

Browse files
committed
Sema: Use InferredGenericSignatureRequest in TypeCheckAttr.cpp
1 parent d847bc1 commit 1262193

File tree

2 files changed

+36
-64
lines changed

2 files changed

+36
-64
lines changed

lib/Sema/TypeCheckAttr.cpp

Lines changed: 35 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "swift/AST/DiagnosticsParse.h"
2626
#include "swift/AST/Effects.h"
2727
#include "swift/AST/GenericEnvironment.h"
28-
#include "swift/AST/GenericSignatureBuilder.h"
2928
#include "swift/AST/ImportCache.h"
3029
#include "swift/AST/ModuleNameLookup.h"
3130
#include "swift/AST/NameLookup.h"
@@ -2231,28 +2230,17 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) {
22312230
return;
22322231
}
22332232

2234-
// Form a new generic signature based on the old one.
2235-
GenericSignatureBuilder Builder(D->getASTContext());
2233+
InferredGenericSignatureRequest request{
2234+
DC->getParentModule(),
2235+
genericSig.getPointer(),
2236+
/*genericParams=*/nullptr,
2237+
WhereClauseOwner(FD, attr),
2238+
/*addedRequirements=*/{},
2239+
/*inferenceSources=*/{},
2240+
/*allowConcreteGenericParams=*/true};
22362241

2237-
// First, add the old generic signature.
2238-
Builder.addGenericSignature(genericSig);
2239-
2240-
// Go over the set of requirements, adding them to the builder.
2241-
WhereClauseOwner(FD, attr).visitRequirements(TypeResolutionStage::Interface,
2242-
[&](const Requirement &req, RequirementRepr *reqRepr) {
2243-
// Add the requirement to the generic signature builder.
2244-
using FloatingRequirementSource =
2245-
GenericSignatureBuilder::FloatingRequirementSource;
2246-
Builder.addRequirement(req, reqRepr,
2247-
FloatingRequirementSource::forExplicit(
2248-
reqRepr->getSeparatorLoc()),
2249-
nullptr, DC->getParentModule());
2250-
return false;
2251-
});
2252-
2253-
// Check the result.
2254-
auto specializedSig = std::move(Builder).computeGenericSignature(
2255-
/*allowConcreteGenericParams=*/true);
2242+
auto specializedSig = evaluateOrDefault(Ctx.evaluator, request,
2243+
GenericSignature());
22562244

22572245
// Check the validity of provided requirements.
22582246
checkSpecializeAttrRequirements(attr, genericSig, specializedSig, Ctx);
@@ -4266,7 +4254,8 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
42664254
// - If the `@differentiable` attribute has a `where` clause, use it to
42674255
// compute the derivative generic signature.
42684256
// - Otherwise, use the original function's generic signature by default.
4269-
derivativeGenSig = original->getGenericSignature();
4257+
auto originalGenSig = original->getGenericSignature();
4258+
derivativeGenSig = originalGenSig;
42704259

42714260
// Handle the `where` clause, if it exists.
42724261
// - Resolve attribute where clause requirements and store in the attribute
@@ -4291,7 +4280,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
42914280
return true;
42924281
}
42934282

4294-
auto originalGenSig = original->getGenericSignature();
42954283
if (!originalGenSig) {
42964284
// `where` clauses are valid only when the original function is generic.
42974285
diags
@@ -4304,51 +4292,34 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
43044292
return true;
43054293
}
43064294

4307-
// Build a new generic signature for autodiff derivative functions.
4308-
GenericSignatureBuilder builder(ctx);
4309-
// Add the original function's generic signature.
4310-
builder.addGenericSignature(originalGenSig);
4311-
4312-
using FloatingRequirementSource =
4313-
GenericSignatureBuilder::FloatingRequirementSource;
4314-
4315-
bool errorOccurred = false;
4316-
WhereClauseOwner(original, attr)
4317-
.visitRequirements(
4318-
TypeResolutionStage::Structural,
4319-
[&](const Requirement &req, RequirementRepr *reqRepr) {
4320-
switch (req.getKind()) {
4321-
case RequirementKind::SameType:
4322-
case RequirementKind::Superclass:
4323-
case RequirementKind::Conformance:
4324-
break;
4325-
4326-
// Layout requirements are not supported.
4327-
case RequirementKind::Layout:
4328-
diags
4329-
.diagnose(attr->getLocation(),
4330-
diag::differentiable_attr_layout_req_unsupported)
4331-
.highlight(reqRepr->getSourceRange());
4332-
errorOccurred = true;
4333-
return false;
4334-
}
4295+
InferredGenericSignatureRequest request{
4296+
original->getParentModule(),
4297+
originalGenSig.getPointer(),
4298+
/*genericParams=*/nullptr,
4299+
WhereClauseOwner(original, attr),
4300+
/*addedRequirements=*/{},
4301+
/*inferenceSources=*/{},
4302+
/*allowConcreteParams=*/true};
4303+
4304+
// Compute generic signature for derivative functions.
4305+
derivativeGenSig = evaluateOrDefault(ctx.evaluator, request,
4306+
GenericSignature());
43354307

4336-
// Add requirement to generic signature builder.
4337-
builder.addRequirement(
4338-
req, reqRepr, FloatingRequirementSource::forExplicit(
4339-
reqRepr->getSeparatorLoc()),
4340-
nullptr, original->getModuleContext());
4341-
return false;
4342-
});
4308+
bool hadInvalidRequirements = false;
4309+
for (auto req : derivativeGenSig.requirementsNotSatisfiedBy(originalGenSig)) {
4310+
if (req.getKind() == RequirementKind::Layout) {
4311+
// Layout requirements are not supported.
4312+
diags
4313+
.diagnose(attr->getLocation(),
4314+
diag::differentiable_attr_layout_req_unsupported);
4315+
hadInvalidRequirements = true;
4316+
}
4317+
}
43434318

4344-
if (errorOccurred) {
4319+
if (hadInvalidRequirements) {
43454320
attr->setInvalid();
43464321
return true;
43474322
}
4348-
4349-
// Compute generic signature for derivative functions.
4350-
derivativeGenSig = std::move(builder).computeGenericSignature(
4351-
/*allowConcreteGenericParams=*/true);
43524323
}
43534324

43544325
attr->setDerivativeGenericSignature(derivativeGenSig);

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ func invalidRequirementConformance<Scalar>(x: Scalar) -> Scalar {
191191
return x
192192
}
193193

194+
// expected-error @+1 {{'@differentiable' attribute does not yet support layout requirements}}
194195
@differentiable(reverse where T: AnyObject)
195196
func invalidAnyObjectRequirement<T: Differentiable>(x: T) -> T {
196197
return x

0 commit comments

Comments
 (0)