25
25
#include " swift/AST/DiagnosticsParse.h"
26
26
#include " swift/AST/Effects.h"
27
27
#include " swift/AST/GenericEnvironment.h"
28
- #include " swift/AST/GenericSignatureBuilder.h"
29
28
#include " swift/AST/ImportCache.h"
30
29
#include " swift/AST/ModuleNameLookup.h"
31
30
#include " swift/AST/NameLookup.h"
@@ -2231,28 +2230,17 @@ void AttributeChecker::visitSpecializeAttr(SpecializeAttr *attr) {
2231
2230
return ;
2232
2231
}
2233
2232
2234
- // Form a new generic signature based on the old one.
2235
- GenericSignatureBuilder Builder (D->getASTContext ());
2233
+ InferredGenericSignatureRequest request{
2234
+ DC->getParentModule (),
2235
+ genericSig.getPointer (),
2236
+ /* genericParams=*/ nullptr ,
2237
+ WhereClauseOwner (FD, attr),
2238
+ /* addedRequirements=*/ {},
2239
+ /* inferenceSources=*/ {},
2240
+ /* allowConcreteGenericParams=*/ true };
2236
2241
2237
- // First, add the old generic signature.
2238
- Builder.addGenericSignature (genericSig);
2239
-
2240
- // Go over the set of requirements, adding them to the builder.
2241
- WhereClauseOwner (FD, attr).visitRequirements (TypeResolutionStage::Interface,
2242
- [&](const Requirement &req, RequirementRepr *reqRepr) {
2243
- // Add the requirement to the generic signature builder.
2244
- using FloatingRequirementSource =
2245
- GenericSignatureBuilder::FloatingRequirementSource;
2246
- Builder.addRequirement (req, reqRepr,
2247
- FloatingRequirementSource::forExplicit (
2248
- reqRepr->getSeparatorLoc ()),
2249
- nullptr , DC->getParentModule ());
2250
- return false ;
2251
- });
2252
-
2253
- // Check the result.
2254
- auto specializedSig = std::move (Builder).computeGenericSignature (
2255
- /* allowConcreteGenericParams=*/ true );
2242
+ auto specializedSig = evaluateOrDefault (Ctx.evaluator , request,
2243
+ GenericSignature ());
2256
2244
2257
2245
// Check the validity of provided requirements.
2258
2246
checkSpecializeAttrRequirements (attr, genericSig, specializedSig, Ctx);
@@ -4266,7 +4254,8 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
4266
4254
// - If the `@differentiable` attribute has a `where` clause, use it to
4267
4255
// compute the derivative generic signature.
4268
4256
// - Otherwise, use the original function's generic signature by default.
4269
- derivativeGenSig = original->getGenericSignature ();
4257
+ auto originalGenSig = original->getGenericSignature ();
4258
+ derivativeGenSig = originalGenSig;
4270
4259
4271
4260
// Handle the `where` clause, if it exists.
4272
4261
// - Resolve attribute where clause requirements and store in the attribute
@@ -4291,7 +4280,6 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
4291
4280
return true ;
4292
4281
}
4293
4282
4294
- auto originalGenSig = original->getGenericSignature ();
4295
4283
if (!originalGenSig) {
4296
4284
// `where` clauses are valid only when the original function is generic.
4297
4285
diags
@@ -4304,51 +4292,34 @@ bool resolveDifferentiableAttrDerivativeGenericSignature(
4304
4292
return true ;
4305
4293
}
4306
4294
4307
- // Build a new generic signature for autodiff derivative functions.
4308
- GenericSignatureBuilder builder (ctx);
4309
- // Add the original function's generic signature.
4310
- builder.addGenericSignature (originalGenSig);
4311
-
4312
- using FloatingRequirementSource =
4313
- GenericSignatureBuilder::FloatingRequirementSource;
4314
-
4315
- bool errorOccurred = false ;
4316
- WhereClauseOwner (original, attr)
4317
- .visitRequirements (
4318
- TypeResolutionStage::Structural,
4319
- [&](const Requirement &req, RequirementRepr *reqRepr) {
4320
- switch (req.getKind ()) {
4321
- case RequirementKind::SameType:
4322
- case RequirementKind::Superclass:
4323
- case RequirementKind::Conformance:
4324
- break ;
4325
-
4326
- // Layout requirements are not supported.
4327
- case RequirementKind::Layout:
4328
- diags
4329
- .diagnose (attr->getLocation (),
4330
- diag::differentiable_attr_layout_req_unsupported)
4331
- .highlight (reqRepr->getSourceRange ());
4332
- errorOccurred = true ;
4333
- return false ;
4334
- }
4295
+ InferredGenericSignatureRequest request{
4296
+ original->getParentModule (),
4297
+ originalGenSig.getPointer (),
4298
+ /* genericParams=*/ nullptr ,
4299
+ WhereClauseOwner (original, attr),
4300
+ /* addedRequirements=*/ {},
4301
+ /* inferenceSources=*/ {},
4302
+ /* allowConcreteParams=*/ true };
4303
+
4304
+ // Compute generic signature for derivative functions.
4305
+ derivativeGenSig = evaluateOrDefault (ctx.evaluator , request,
4306
+ GenericSignature ());
4335
4307
4336
- // Add requirement to generic signature builder.
4337
- builder.addRequirement (
4338
- req, reqRepr, FloatingRequirementSource::forExplicit (
4339
- reqRepr->getSeparatorLoc ()),
4340
- nullptr , original->getModuleContext ());
4341
- return false ;
4342
- });
4308
+ bool hadInvalidRequirements = false ;
4309
+ for (auto req : derivativeGenSig.requirementsNotSatisfiedBy (originalGenSig)) {
4310
+ if (req.getKind () == RequirementKind::Layout) {
4311
+ // Layout requirements are not supported.
4312
+ diags
4313
+ .diagnose (attr->getLocation (),
4314
+ diag::differentiable_attr_layout_req_unsupported);
4315
+ hadInvalidRequirements = true ;
4316
+ }
4317
+ }
4343
4318
4344
- if (errorOccurred ) {
4319
+ if (hadInvalidRequirements ) {
4345
4320
attr->setInvalid ();
4346
4321
return true ;
4347
4322
}
4348
-
4349
- // Compute generic signature for derivative functions.
4350
- derivativeGenSig = std::move (builder).computeGenericSignature (
4351
- /* allowConcreteGenericParams=*/ true );
4352
4323
}
4353
4324
4354
4325
attr->setDerivativeGenericSignature (derivativeGenSig);
0 commit comments