|
16 | 16 |
|
17 | 17 | #include "TypeChecker.h"
|
18 | 18 | #include "swift/AST/GenericSignature.h"
|
| 19 | +#include "swift/Basic/OptionSet.h" |
19 | 20 | #include "swift/Sema/ConstraintGraph.h"
|
20 | 21 | #include "swift/Sema/ConstraintSystem.h"
|
21 | 22 | #include "llvm/ADT/BitVector.h"
|
22 | 23 | #include "llvm/ADT/DenseMap.h"
|
23 | 24 | #include "llvm/ADT/SmallVector.h"
|
24 | 25 | #include "llvm/ADT/TinyPtrVector.h"
|
| 26 | +#include "llvm/Support/SaveAndRestore.h" |
25 | 27 | #include "llvm/Support/raw_ostream.h"
|
26 | 28 | #include <cstddef>
|
27 | 29 | #include <functional>
|
@@ -187,6 +189,144 @@ static void determineBestChoicesInContext(
|
187 | 189 | /*allow fixes*/ false, listener, None);
|
188 | 190 | };
|
189 | 191 |
|
| 192 | + // Determine whether the candidate type is a subclass of the superclass |
| 193 | + // type. |
| 194 | + auto isSubclassOf = [&](Type candidateType, Type superclassType) { |
| 195 | + // Conversion from a concrete type to its existential value. |
| 196 | + if (superclassType->isExistentialType() && !superclassType->isAny()) { |
| 197 | + llvm::SaveAndRestore<ConstraintSystemOptions> options( |
| 198 | + cs.Options, cs.Options - ConstraintSystemFlags::AllowFixes); |
| 199 | + |
| 200 | + auto result = cs.matchExistentialTypes( |
| 201 | + candidateType, superclassType, ConstraintKind::SelfObjectOfProtocol, |
| 202 | + /*options=*/{}, cs.getConstraintLocator({})); |
| 203 | + return !result.isFailure(); |
| 204 | + } |
| 205 | + |
| 206 | + auto *subclassDecl = candidateType->getClassOrBoundGenericClass(); |
| 207 | + auto *superclassDecl = superclassType->getClassOrBoundGenericClass(); |
| 208 | + |
| 209 | + if (!(subclassDecl && superclassDecl)) |
| 210 | + return false; |
| 211 | + |
| 212 | + return superclassDecl->isSuperclassOf(subclassDecl); |
| 213 | + }; |
| 214 | + |
| 215 | + enum class MatchFlag { |
| 216 | + OnParam, |
| 217 | + Literal, |
| 218 | + }; |
| 219 | + |
| 220 | + using MatchOptions = OptionSet<MatchFlag>; |
| 221 | + |
| 222 | + // Perform a limited set of checks to determine whether the candidate |
| 223 | + // could possibly match the parameter type: |
| 224 | + // |
| 225 | + // - Equality |
| 226 | + // - Protocol conformance(s) |
| 227 | + // - Optional injection |
| 228 | + // - Superclass conversion |
| 229 | + // - Array-to-pointer conversion |
| 230 | + // - Value to existential conversion |
| 231 | + // - Exact match on top-level types |
| 232 | + std::function<double(GenericSignature, Type, Type, MatchOptions)> |
| 233 | + scoreCandidateMatch = [&](GenericSignature genericSig, |
| 234 | + Type candidateType, Type paramType, |
| 235 | + MatchOptions options) -> double { |
| 236 | + // Dependent members cannot be handled here because |
| 237 | + // they require substitution of the base type which |
| 238 | + // could come from a different argument. |
| 239 | + if (paramType->getAs<DependentMemberType>()) |
| 240 | + return 0; |
| 241 | + |
| 242 | + // Exact match between candidate and parameter types. |
| 243 | + if (candidateType->isEqual(paramType)) |
| 244 | + return options.contains(MatchFlag::Literal) ? 0.3 : 1; |
| 245 | + |
| 246 | + if (options.contains(MatchFlag::Literal)) |
| 247 | + return 0; |
| 248 | + |
| 249 | + // Check whether match would require optional injection. |
| 250 | + { |
| 251 | + SmallVector<Type, 2> candidateOptionals; |
| 252 | + SmallVector<Type, 2> paramOptionals; |
| 253 | + |
| 254 | + candidateType = |
| 255 | + candidateType->lookThroughAllOptionalTypes(candidateOptionals); |
| 256 | + paramType = paramType->lookThroughAllOptionalTypes(paramOptionals); |
| 257 | + |
| 258 | + if (!candidateOptionals.empty() || !paramOptionals.empty()) { |
| 259 | + if (paramOptionals.size() >= candidateOptionals.size()) |
| 260 | + return scoreCandidateMatch(genericSig, candidateType, paramType, |
| 261 | + options); |
| 262 | + |
| 263 | + // Optionality mismatch. |
| 264 | + return 0; |
| 265 | + } |
| 266 | + } |
| 267 | + |
| 268 | + // Candidate could be injected into optional parameter type |
| 269 | + // or converted to a superclass. |
| 270 | + if (isSubclassOf(candidateType, paramType)) |
| 271 | + return 1; |
| 272 | + |
| 273 | + // Possible Array<T> -> Unsafe*Pointer conversion. |
| 274 | + if (options.contains(MatchFlag::OnParam)) { |
| 275 | + if (cs.isArrayType(candidateType) && |
| 276 | + paramType->getAnyPointerElementType()) |
| 277 | + return 1; |
| 278 | + } |
| 279 | + |
| 280 | + // If both argument and parameter are tuples of the same arity, |
| 281 | + // it's a match. |
| 282 | + { |
| 283 | + if (auto *candidateTuple = candidateType->getAs<TupleType>()) { |
| 284 | + auto *paramTuple = paramType->getAs<TupleType>(); |
| 285 | + if (paramTuple && |
| 286 | + candidateTuple->getNumElements() == paramTuple->getNumElements()) |
| 287 | + return 1; |
| 288 | + } |
| 289 | + } |
| 290 | + |
| 291 | + // Check protocol requirement(s) if this parameter is a |
| 292 | + // generic parameter type. |
| 293 | + GenericSignature::RequiredProtocols protocolRequirements; |
| 294 | + if (genericSig) { |
| 295 | + if (auto *GP = paramType->getAs<GenericTypeParamType>()) { |
| 296 | + protocolRequirements = genericSig->getRequiredProtocols(GP); |
| 297 | + // It's a generic parameter which might be connected via |
| 298 | + // same-type constraints to other generic parameters but |
| 299 | + // we cannot check that here, so let's add a tiny score |
| 300 | + // just to acknowledge that it could possibly match. |
| 301 | + if (protocolRequirements.empty()) { |
| 302 | + return 0.01; |
| 303 | + } |
| 304 | + |
| 305 | + if (llvm::all_of(protocolRequirements, [&](ProtocolDecl *protocol) { |
| 306 | + return TypeChecker::conformsToProtocol(candidateType, protocol, |
| 307 | + cs.DC->getParentModule(), |
| 308 | + /*allowMissing=*/false); |
| 309 | + })) |
| 310 | + return 0.7; |
| 311 | + } |
| 312 | + } |
| 313 | + |
| 314 | + // Parameter is generic, let's check whether top-level |
| 315 | + // types match i.e. Array<Element> as a parameter. |
| 316 | + // |
| 317 | + // This is slightly better than all of the conformances matching |
| 318 | + // because the parameter is concrete and could split the graph. |
| 319 | + if (paramType->hasTypeParameter()) { |
| 320 | + auto *candidateDecl = candidateType->getAnyNominal(); |
| 321 | + auto *paramDecl = paramType->getAnyNominal(); |
| 322 | + |
| 323 | + if (candidateDecl && paramDecl && candidateDecl == paramDecl) |
| 324 | + return 0.8; |
| 325 | + } |
| 326 | + |
| 327 | + return 0; |
| 328 | + }; |
| 329 | + |
190 | 330 | // The choice with the best score.
|
191 | 331 | double bestScore = 0.0;
|
192 | 332 | SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
|
@@ -256,23 +396,6 @@ static void determineBestChoicesInContext(
|
256 | 396 | if (paramType->is<FunctionType>())
|
257 | 397 | continue;
|
258 | 398 |
|
259 |
| - // Check protocol requirement(s) if this parameter is a |
260 |
| - // generic parameter type. |
261 |
| - GenericSignature::RequiredProtocols protocolRequirements; |
262 |
| - if (genericSig) { |
263 |
| - if (auto *GP = paramType->getAs<GenericTypeParamType>()) { |
264 |
| - protocolRequirements = genericSig->getRequiredProtocols(GP); |
265 |
| - // It's a generic parameter which might be connected via |
266 |
| - // same-type constraints to other generic parameters but |
267 |
| - // we cannot check that here, so let's ignore it. |
268 |
| - if (protocolRequirements.empty()) |
269 |
| - continue; |
270 |
| - } |
271 |
| - |
272 |
| - if (paramType->getAs<DependentMemberType>()) |
273 |
| - return; |
274 |
| - } |
275 |
| - |
276 | 399 | // The idea here is to match the parameter type against
|
277 | 400 | // all of the argument candidate types and pick the best
|
278 | 401 | // match (i.e. exact equality one).
|
@@ -306,32 +429,14 @@ static void determineBestChoicesInContext(
|
306 | 429 | // The specifier only matters for `inout` check.
|
307 | 430 | candidateType = candidateType->getWithoutSpecifierType();
|
308 | 431 |
|
309 |
| - // We don't check generic requirements against literal default |
310 |
| - // types because it creates more noise than signal for operators. |
311 |
| - if (!protocolRequirements.empty() && !isLiteralDefault) { |
312 |
| - if (llvm::all_of( |
313 |
| - protocolRequirements, [&](ProtocolDecl *protocol) { |
314 |
| - return TypeChecker::conformsToProtocol( |
315 |
| - candidateType, protocol, cs.DC->getParentModule(), |
316 |
| - /*allowMissing=*/false); |
317 |
| - })) { |
318 |
| - // Score is lower here because we still prefer concrete |
319 |
| - // overloads over the generic ones when possible. |
320 |
| - bestCandidateScore = std::max(bestCandidateScore, 0.7); |
321 |
| - continue; |
322 |
| - } |
323 |
| - } else if (paramType->hasTypeParameter()) { |
324 |
| - // i.e. Array<Element> or Optional<Wrapped> as a parameter. |
325 |
| - // This is slightly better than all of the conformances matching |
326 |
| - // because the parameter is concrete and could split the graph. |
327 |
| - if (paramType->getAnyNominal() == candidateType->getAnyNominal()) { |
328 |
| - bestCandidateScore = std::max(bestCandidateScore, 0.8); |
329 |
| - continue; |
330 |
| - } |
331 |
| - } else if (candidateType->isEqual(paramType)) { |
332 |
| - // Exact match on one of the candidate bindings. |
333 |
| - bestCandidateScore = |
334 |
| - std::max(bestCandidateScore, isLiteralDefault ? 0.3 : 1.0); |
| 432 | + MatchOptions options(MatchFlag::OnParam); |
| 433 | + if (isLiteralDefault) |
| 434 | + options |= MatchFlag::Literal; |
| 435 | + |
| 436 | + auto score = scoreCandidateMatch(genericSig, candidateType, |
| 437 | + paramType, options); |
| 438 | + if (score > 0) { |
| 439 | + bestCandidateScore = std::max(bestCandidateScore, score); |
335 | 440 | continue;
|
336 | 441 | }
|
337 | 442 |
|
@@ -365,11 +470,12 @@ static void determineBestChoicesInContext(
|
365 | 470 | if (score > 0 ||
|
366 | 471 | (decl->isOperator() &&
|
367 | 472 | !decl->getBaseIdentifier().isStandardComparisonOperator())) {
|
368 |
| - if (llvm::any_of( |
369 |
| - resultTypes, [&overloadType](const Type candidateResultTy) { |
370 |
| - auto overloadResultTy = overloadType->getResult(); |
371 |
| - return candidateResultTy->isEqual(overloadResultTy); |
372 |
| - })) { |
| 473 | + if (llvm::any_of(resultTypes, [&](const Type candidateResultTy) { |
| 474 | + return scoreCandidateMatch(genericSig, |
| 475 | + overloadType->getResult(), |
| 476 | + candidateResultTy, |
| 477 | + /*options=*/{}) > 0; |
| 478 | + })) { |
373 | 479 | score += 1.0;
|
374 | 480 | }
|
375 | 481 | }
|
|
0 commit comments