Skip to content

Commit c2f7451

Browse files
committed
[CSOptimizer] Favor SIMD related arithmetic operator choices if argument is SIMD<N> type
1 parent 672ae3d commit c2f7451

File tree

1 file changed

+56
-5
lines changed

1 file changed

+56
-5
lines changed

lib/Sema/CSOptimizer.cpp

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,34 @@ void forEachDisjunctionChoice(
8585
}
8686
}
8787

88+
static bool isSIMDType(Type type) {
89+
auto *NTD = dyn_cast_or_null<StructDecl>(type->getAnyNominal());
90+
if (!NTD)
91+
return false;
92+
93+
auto typeName = NTD->getName().str();
94+
if (!typeName.startswith("SIMD"))
95+
return false;
96+
97+
return NTD->getParentModule()->getName().is("Swift");
98+
}
99+
100+
static bool isArithmeticOperatorOnSIMDProtocol(ValueDecl *decl) {
101+
if (!isSIMDOperator(decl))
102+
return false;
103+
104+
if (!decl->getBaseIdentifier().isArithmeticOperator())
105+
return false;
106+
107+
auto *DC = decl->getDeclContext();
108+
if (auto *P = DC->getSelfProtocolDecl()) {
109+
if (auto knownKind = P->getKnownProtocolKind())
110+
return *knownKind == KnownProtocolKind::SIMD;
111+
}
112+
113+
return false;
114+
}
115+
88116
} // end anonymous namespace
89117

90118
/// Given a set of disjunctions, attempt to determine
@@ -152,18 +180,31 @@ static void determineBestChoicesInContext(
152180
resultTypes.push_back(resultType);
153181
}
154182

183+
auto isViableOverload = [&](ValueDecl *decl) {
184+
// Allow standard arithmetic operator overloads on SIMD protocol
185+
// to be considered because we can favor them when then argument
186+
// is a known SIMD<N> type.
187+
if (isArithmeticOperatorOnSIMDProtocol(decl))
188+
return true;
189+
190+
// Don't consider generic overloads because we need conformance
191+
// checking functionality to determine best favoring, preferring
192+
// such overloads based only on concrete types leads to subpar
193+
// choices due to missed information.
194+
if (decl->getInterfaceType()->is<GenericFunctionType>())
195+
return false;
196+
197+
return true;
198+
};
199+
155200
// The choice with the best score.
156201
double bestScore = 0.0;
157202
SmallVector<std::pair<Constraint *, double>, 2> favoredChoices;
158203

159204
forEachDisjunctionChoice(
160205
cs, disjunction,
161206
[&](Constraint *choice, ValueDecl *decl, FunctionType *overloadType) {
162-
// Don't consider generic overloads because we need conformance
163-
// checking functionality to determine best favoring, preferring
164-
// such overloads based only on concrete types leads to subpar
165-
// choices due to missed information.
166-
if (decl->getInterfaceType()->is<GenericFunctionType>())
207+
if (!isViableOverload(decl))
167208
return;
168209

169210
if (overloadType->getNumParams() != argFuncType->getNumParams())
@@ -204,6 +245,16 @@ static void determineBestChoicesInContext(
204245
if (candidateType->isEqual(paramType)) {
205246
argScore = std::max(
206247
argScore, /*fromLiteral=*/candidate.second ? 0.3 : 1.0);
248+
continue;
249+
}
250+
251+
// If argument is SIMD<N> type i.e. SIMD1<...> it's appropriate
252+
// to favor of the overloads that are declared on SIMD protocol
253+
// and expect a particular `Scalar` if it's known.
254+
if (isSIMDType(candidateType) &&
255+
isArithmeticOperatorOnSIMDProtocol(decl)) {
256+
argScore = std::max(argScore, 1.0);
257+
continue;
207258
}
208259
}
209260

0 commit comments

Comments
 (0)