|
15 | 15 | //===----------------------------------------------------------------------===//
|
16 | 16 |
|
17 | 17 | #include "TypeChecker.h"
|
| 18 | +#include "swift/AST/ExistentialLayout.h" |
18 | 19 | #include "swift/AST/GenericSignature.h"
|
| 20 | +#include "swift/Basic/OptionSet.h" |
19 | 21 | #include "swift/Sema/ConstraintGraph.h"
|
20 | 22 | #include "swift/Sema/ConstraintSystem.h"
|
21 | 23 | #include "llvm/ADT/BitVector.h"
|
22 | 24 | #include "llvm/ADT/DenseMap.h"
|
23 | 25 | #include "llvm/ADT/SmallVector.h"
|
24 | 26 | #include "llvm/ADT/TinyPtrVector.h"
|
| 27 | +#include "llvm/Support/SaveAndRestore.h" |
25 | 28 | #include "llvm/Support/raw_ostream.h"
|
26 | 29 | #include <cstddef>
|
27 | 30 | #include <functional>
|
@@ -187,6 +190,162 @@ static void determineBestChoicesInContext(
|
187 | 190 | /*allow fixes*/ false, listener, None);
|
188 | 191 | };
|
189 | 192 |
|
| 193 | + // Determine whether the candidate type is a subclass of the superclass |
| 194 | + // type. |
| 195 | + std::function<bool(Type, Type)> isSubclassOf = [&](Type candidateType, |
| 196 | + Type superclassType) { |
| 197 | + // Conversion from a concrete type to its existential value. |
| 198 | + if (superclassType->isExistentialType() && !superclassType->isAny()) { |
| 199 | + auto layout = superclassType->getExistentialLayout(); |
| 200 | + |
| 201 | + if (auto layoutConstraint = layout.getLayoutConstraint()) { |
| 202 | + if (layoutConstraint->isClass() && |
| 203 | + !(candidateType->isClassExistentialType() || |
| 204 | + candidateType->mayHaveSuperclass())) |
| 205 | + return false; |
| 206 | + } |
| 207 | + |
| 208 | + if (layout.explicitSuperclass && |
| 209 | + !isSubclassOf(candidateType, layout.explicitSuperclass)) |
| 210 | + return false; |
| 211 | + |
| 212 | + return llvm::all_of(layout.getProtocols(), [&](ProtocolDecl *P) { |
| 213 | + if (auto superclass = P->getSuperclass()) { |
| 214 | + if (!isSubclassOf(candidateType, superclass)) |
| 215 | + return false; |
| 216 | + } |
| 217 | + |
| 218 | + return bool(TypeChecker::containsProtocol( |
| 219 | + candidateType, P, cs.DC->getParentModule(), |
| 220 | + /*skipConditionalRequirements=*/true, |
| 221 | + /*allowMissing=*/false)); |
| 222 | + }); |
| 223 | + } |
| 224 | + |
| 225 | + auto *subclassDecl = candidateType->getClassOrBoundGenericClass(); |
| 226 | + auto *superclassDecl = superclassType->getClassOrBoundGenericClass(); |
| 227 | + |
| 228 | + if (!(subclassDecl && superclassDecl)) |
| 229 | + return false; |
| 230 | + |
| 231 | + return superclassDecl->isSuperclassOf(subclassDecl); |
| 232 | + }; |
| 233 | + |
| 234 | + enum class MatchFlag { |
| 235 | + OnParam = 0x01, |
| 236 | + Literal = 0x02, |
| 237 | + }; |
| 238 | + |
| 239 | + using MatchOptions = OptionSet<MatchFlag>; |
| 240 | + |
| 241 | + // Perform a limited set of checks to determine whether the candidate |
| 242 | + // could possibly match the parameter type: |
| 243 | + // |
| 244 | + // - Equality |
| 245 | + // - Protocol conformance(s) |
| 246 | + // - Optional injection |
| 247 | + // - Superclass conversion |
| 248 | + // - Array-to-pointer conversion |
| 249 | + // - Value to existential conversion |
| 250 | + // - Exact match on top-level types |
| 251 | + std::function<double(GenericSignature, Type, Type, MatchOptions)> |
| 252 | + scoreCandidateMatch = [&](GenericSignature genericSig, |
| 253 | + Type candidateType, Type paramType, |
| 254 | + MatchOptions options) -> double { |
| 255 | + // Dependent members cannot be handled here because |
| 256 | + // they require substitution of the base type which |
| 257 | + // could come from a different argument. |
| 258 | + if (paramType->getAs<DependentMemberType>()) |
| 259 | + return 0; |
| 260 | + |
| 261 | + // Exact match between candidate and parameter types. |
| 262 | + if (candidateType->isEqual(paramType)) |
| 263 | + return options.contains(MatchFlag::Literal) ? 0.3 : 1; |
| 264 | + |
| 265 | + if (options.contains(MatchFlag::Literal)) |
| 266 | + return 0; |
| 267 | + |
| 268 | + // Check whether match would require optional injection. |
| 269 | + { |
| 270 | + SmallVector<Type, 2> candidateOptionals; |
| 271 | + SmallVector<Type, 2> paramOptionals; |
| 272 | + |
| 273 | + candidateType = |
| 274 | + candidateType->lookThroughAllOptionalTypes(candidateOptionals); |
| 275 | + paramType = paramType->lookThroughAllOptionalTypes(paramOptionals); |
| 276 | + |
| 277 | + if (!candidateOptionals.empty() || !paramOptionals.empty()) { |
| 278 | + if (paramOptionals.size() >= candidateOptionals.size()) |
| 279 | + return scoreCandidateMatch(genericSig, candidateType, paramType, |
| 280 | + options); |
| 281 | + |
| 282 | + // Optionality mismatch. |
| 283 | + return 0; |
| 284 | + } |
| 285 | + } |
| 286 | + |
| 287 | + // Candidate could be injected into optional parameter type |
| 288 | + // or converted to a superclass. |
| 289 | + if (isSubclassOf(candidateType, paramType)) |
| 290 | + return 1; |
| 291 | + |
| 292 | + // Possible Array<T> -> Unsafe*Pointer conversion. |
| 293 | + if (options.contains(MatchFlag::OnParam)) { |
| 294 | + if (candidateType->isArrayType() && |
| 295 | + paramType->getAnyPointerElementType()) |
| 296 | + return 1; |
| 297 | + } |
| 298 | + |
| 299 | + // If both argument and parameter are tuples of the same arity, |
| 300 | + // it's a match. |
| 301 | + { |
| 302 | + if (auto *candidateTuple = candidateType->getAs<TupleType>()) { |
| 303 | + auto *paramTuple = paramType->getAs<TupleType>(); |
| 304 | + if (paramTuple && |
| 305 | + candidateTuple->getNumElements() == paramTuple->getNumElements()) |
| 306 | + return 1; |
| 307 | + } |
| 308 | + } |
| 309 | + |
| 310 | + // Check protocol requirement(s) if this parameter is a |
| 311 | + // generic parameter type. |
| 312 | + GenericSignature::RequiredProtocols protocolRequirements; |
| 313 | + if (genericSig) { |
| 314 | + if (auto *GP = paramType->getAs<GenericTypeParamType>()) { |
| 315 | + protocolRequirements = genericSig->getRequiredProtocols(GP); |
| 316 | + // It's a generic parameter which might be connected via |
| 317 | + // same-type constraints to other generic parameters but |
| 318 | + // we cannot check that here, so let's add a tiny score |
| 319 | + // just to acknowledge that it could possibly match. |
| 320 | + if (protocolRequirements.empty()) { |
| 321 | + return 0.01; |
| 322 | + } |
| 323 | + |
| 324 | + if (llvm::all_of(protocolRequirements, [&](ProtocolDecl *protocol) { |
| 325 | + return TypeChecker::conformsToProtocol(candidateType, protocol, |
| 326 | + cs.DC->getParentModule(), |
| 327 | + /*allowMissing=*/false); |
| 328 | + })) |
| 329 | + return 0.7; |
| 330 | + } |
| 331 | + } |
| 332 | + |
| 333 | + // Parameter is generic, let's check whether top-level |
| 334 | + // types match i.e. Array<Element> as a parameter. |
| 335 | + // |
| 336 | + // This is slightly better than all of the conformances matching |
| 337 | + // because the parameter is concrete and could split the graph. |
| 338 | + if (paramType->hasTypeParameter()) { |
| 339 | + auto *candidateDecl = candidateType->getAnyNominal(); |
| 340 | + auto *paramDecl = paramType->getAnyNominal(); |
| 341 | + |
| 342 | + if (candidateDecl && paramDecl && candidateDecl == paramDecl) |
| 343 | + return 0.8; |
| 344 | + } |
| 345 | + |
| 346 | + return 0; |
| 347 | + }; |
| 348 | + |
190 | 349 | // The choice with the best score.
|
191 | 350 | double bestScore = 0.0;
|
192 | 351 | SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
|
@@ -256,23 +415,6 @@ static void determineBestChoicesInContext(
|
256 | 415 | if (paramType->is<FunctionType>())
|
257 | 416 | continue;
|
258 | 417 |
|
259 |
| - // Check protocol requirement(s) if this parameter is a |
260 |
| - // generic parameter type. |
261 |
| - GenericSignature::RequiredProtocols protocolRequirements; |
262 |
| - if (genericSig) { |
263 |
| - if (auto *GP = paramType->getAs<GenericTypeParamType>()) { |
264 |
| - protocolRequirements = genericSig->getRequiredProtocols(GP); |
265 |
| - // It's a generic parameter which might be connected via |
266 |
| - // same-type constraints to other generic parameters but |
267 |
| - // we cannot check that here, so let's ignore it. |
268 |
| - if (protocolRequirements.empty()) |
269 |
| - continue; |
270 |
| - } |
271 |
| - |
272 |
| - if (paramType->getAs<DependentMemberType>()) |
273 |
| - return; |
274 |
| - } |
275 |
| - |
276 | 418 | // The idea here is to match the parameter type against
|
277 | 419 | // all of the argument candidate types and pick the best
|
278 | 420 | // match (i.e. exact equality one).
|
@@ -306,32 +448,14 @@ static void determineBestChoicesInContext(
|
306 | 448 | // The specifier only matters for `inout` check.
|
307 | 449 | candidateType = candidateType->getWithoutSpecifierType();
|
308 | 450 |
|
309 |
| - // We don't check generic requirements against literal default |
310 |
| - // types because it creates more noise than signal for operators. |
311 |
| - if (!protocolRequirements.empty() && !isLiteralDefault) { |
312 |
| - if (llvm::all_of( |
313 |
| - protocolRequirements, [&](ProtocolDecl *protocol) { |
314 |
| - return TypeChecker::conformsToProtocol( |
315 |
| - candidateType, protocol, cs.DC->getParentModule(), |
316 |
| - /*allowMissing=*/false); |
317 |
| - })) { |
318 |
| - // Score is lower here because we still prefer concrete |
319 |
| - // overloads over the generic ones when possible. |
320 |
| - bestCandidateScore = std::max(bestCandidateScore, 0.7); |
321 |
| - continue; |
322 |
| - } |
323 |
| - } else if (paramType->hasTypeParameter()) { |
324 |
| - // i.e. Array<Element> or Optional<Wrapped> as a parameter. |
325 |
| - // This is slightly better than all of the conformances matching |
326 |
| - // because the parameter is concrete and could split the graph. |
327 |
| - if (paramType->getAnyNominal() == candidateType->getAnyNominal()) { |
328 |
| - bestCandidateScore = std::max(bestCandidateScore, 0.8); |
329 |
| - continue; |
330 |
| - } |
331 |
| - } else if (candidateType->isEqual(paramType)) { |
332 |
| - // Exact match on one of the candidate bindings. |
333 |
| - bestCandidateScore = |
334 |
| - std::max(bestCandidateScore, isLiteralDefault ? 0.3 : 1.0); |
| 451 | + MatchOptions options(MatchFlag::OnParam); |
| 452 | + if (isLiteralDefault) |
| 453 | + options |= MatchFlag::Literal; |
| 454 | + |
| 455 | + auto score = scoreCandidateMatch(genericSig, candidateType, |
| 456 | + paramType, options); |
| 457 | + if (score > 0) { |
| 458 | + bestCandidateScore = std::max(bestCandidateScore, score); |
335 | 459 | continue;
|
336 | 460 | }
|
337 | 461 |
|
@@ -365,11 +489,12 @@ static void determineBestChoicesInContext(
|
365 | 489 | if (score > 0 ||
|
366 | 490 | (decl->isOperator() &&
|
367 | 491 | !decl->getBaseIdentifier().isStandardComparisonOperator())) {
|
368 |
| - if (llvm::any_of( |
369 |
| - resultTypes, [&overloadType](const Type candidateResultTy) { |
370 |
| - auto overloadResultTy = overloadType->getResult(); |
371 |
| - return candidateResultTy->isEqual(overloadResultTy); |
372 |
| - })) { |
| 492 | + if (llvm::any_of(resultTypes, [&](const Type candidateResultTy) { |
| 493 | + return scoreCandidateMatch(genericSig, |
| 494 | + overloadType->getResult(), |
| 495 | + candidateResultTy, |
| 496 | + /*options=*/{}) > 0; |
| 497 | + })) { |
373 | 498 | score += 1.0;
|
374 | 499 | }
|
375 | 500 | }
|
|
0 commit comments