Skip to content

Commit fefbf82

Browse files
committed
Derive ElementaryFunctions conformances for structs.
`ElementaryFunctions` derived conformances enable elementary math functions to work with product spaces formed from `ElementaryFunctions`-conforming types. Enables efficient, elegant mathematical optimizers.
1 parent 9040108 commit fefbf82

13 files changed

+687
-8
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2532,6 +2532,8 @@ ERROR(parameterized_invalid_parameters_struct,none,
25322532
"invalid", (Type))
25332533
ERROR(broken_additive_arithmetic_requirement,none,
25342534
"AdditiveArithmetic protocol is broken: unexpected requirement", ())
2535+
ERROR(broken_elementary_functions_requirement,none,
2536+
"ElementaryFunctions protocol is broken: unexpected requirement", ())
25352537
ERROR(broken_vector_protocol_requirement,none,
25362538
"VectorProtocol protocol is broken: unexpected requirement", ())
25372539
ERROR(broken_differentiable_requirement,none,

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ PROTOCOL(Encodable)
7878
PROTOCOL(Decodable)
7979
// SWIFT_ENABLE_TENSORFLOW
8080
PROTOCOL(AdditiveArithmetic)
81+
PROTOCOL(ElementaryFunctions)
8182
PROTOCOL(KeyPathIterable)
8283
PROTOCOL(TensorArrayProtocol)
8384
PROTOCOL(TensorGroup)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4206,6 +4206,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42064206
case KnownProtocolKind::StringInterpolationProtocol:
42074207
// SWIFT_ENABLE_TENSORFLOW
42084208
case KnownProtocolKind::AdditiveArithmetic:
4209+
case KnownProtocolKind::ElementaryFunctions:
42094210
case KnownProtocolKind::KeyPathIterable:
42104211
case KnownProtocolKind::TensorArrayProtocol:
42114212
case KnownProtocolKind::TensorGroup:

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_swift_host_library(swiftSema STATIC
2828
DerivedConformanceError.cpp
2929
# SWIFT_ENABLE_TENSORFLOW
3030
DerivedConformanceAdditiveArithmetic.cpp
31+
DerivedConformanceElementaryFunctions.cpp
3132
DerivedConformanceVectorProtocol.cpp
3233
DerivedConformanceDifferentiable.cpp
3334
DerivedConformanceKeyPathIterable.cpp
Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
//===--- DerivedConformanceElementaryFunctions.cpp ------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
//
13+
// This file implements explicit derivation of the ElementaryFunctions protocol
14+
// for struct types.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#include "CodeSynthesis.h"
19+
#include "TypeChecker.h"
20+
#include "swift/AST/Decl.h"
21+
#include "swift/AST/Expr.h"
22+
#include "swift/AST/GenericSignature.h"
23+
#include "swift/AST/Module.h"
24+
#include "swift/AST/ParameterList.h"
25+
#include "swift/AST/Pattern.h"
26+
#include "swift/AST/ProtocolConformance.h"
27+
#include "swift/AST/Stmt.h"
28+
#include "swift/AST/Types.h"
29+
#include "DerivedConformances.h"
30+
31+
using namespace swift;
32+
33+
// Represents synthesizable `ElementaryFunction` protocol requirements.
34+
enum ElementaryFunction {
35+
#define ELEMENTARY_FUNCTION(ID, NAME) ID,
36+
#include "DerivedConformanceElementaryFunctions.def"
37+
#undef ELEMENTARY_FUNCTION
38+
};
39+
40+
static StringRef getElementaryFunctionName(ElementaryFunction op) {
41+
switch (op) {
42+
#define ELEMENTARY_FUNCTION(ID, NAME) case ElementaryFunction::ID: return NAME;
43+
#include "DerivedConformanceElementaryFunctions.def"
44+
#undef ELEMENTARY_FUNCTION
45+
}
46+
}
47+
48+
// Return the protocol requirement with the specified name.
49+
// TODO: Move function to shared place for use with other derived conformances.
50+
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
51+
auto lookup = proto->lookupDirect(name);
52+
llvm::erase_if(lookup, [](ValueDecl *v) {
53+
return !isa<ProtocolDecl>(v->getDeclContext()) ||
54+
!v->isProtocolRequirement();
55+
});
56+
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
57+
return lookup.front();
58+
}
59+
60+
// Return true if given nominal type has a `let` stored with an initial value.
61+
// TODO: Move function to shared place for use with other derived conformances.
62+
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
63+
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
64+
return v->isLet() && v->hasInitialValue();
65+
});
66+
}
67+
68+
// Return the `ElementaryFunction` protocol requirement corresponding to the
69+
// given elementary function.
70+
static ValueDecl *getElementaryFunctionRequirement(
71+
ASTContext &C, ElementaryFunction op) {
72+
auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
73+
auto operatorId = C.getIdentifier(getElementaryFunctionName(op));
74+
switch (op) {
75+
#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \
76+
case ID: \
77+
return getProtocolRequirement(mathProto, operatorId);
78+
#include "DerivedConformanceElementaryFunctions.def"
79+
#undef ELEMENTARY_FUNCTION_UNARY
80+
case Root:
81+
return getProtocolRequirement(mathProto, operatorId);
82+
case Pow:
83+
case PowInt:
84+
auto lookup = mathProto->lookupDirect(operatorId);
85+
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
86+
[](ValueDecl *v) {
87+
return !isa<ProtocolDecl>(
88+
v->getDeclContext()) ||
89+
!v->isProtocolRequirement();
90+
}),
91+
lookup.end());
92+
assert(lookup.size() == 2 && "Expected two 'pow' functions");
93+
auto *powFuncDecl = cast<FuncDecl>(lookup.front());
94+
auto secondParamType =
95+
powFuncDecl->getParameters()->get(1)->getInterfaceType();
96+
if (secondParamType->getAnyNominal() == C.getIntDecl())
97+
return op == PowInt ? lookup.front() : lookup[1];
98+
else
99+
return op == PowInt ? lookup[1] : lookup.front();
100+
}
101+
}
102+
103+
// Get the effective memberwise initializer of the given nominal type, or create
104+
// it if it does not exist.
105+
static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer(
106+
TypeChecker &TC, NominalTypeDecl *nominal) {
107+
auto &C = nominal->getASTContext();
108+
if (auto *initDecl = nominal->getEffectiveMemberwiseInitializer())
109+
return initDecl;
110+
auto *initDecl = createImplicitConstructor(
111+
TC, nominal, ImplicitConstructorKind::Memberwise);
112+
nominal->addMember(initDecl);
113+
C.addSynthesizedDecl(initDecl);
114+
return initDecl;
115+
}
116+
117+
bool DerivedConformance::canDeriveElementaryFunctions(NominalTypeDecl *nominal,
118+
DeclContext *DC) {
119+
// Nominal type must be a struct. (Zero stored properties is okay.)
120+
auto *structDecl = dyn_cast<StructDecl>(nominal);
121+
if (!structDecl)
122+
return false;
123+
// Must not have any `let` stored properties with an initial value.
124+
// - This restriction may be lifted later with support for "true" memberwise
125+
// initializers that initialize all stored properties, including initial
126+
// value information.
127+
if (hasLetStoredPropertyWithInitialValue(nominal))
128+
return false;
129+
// All stored properties must conform to `ElementaryFunctions`.
130+
auto &C = nominal->getASTContext();
131+
auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
132+
return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
133+
if (!v->hasInterfaceType())
134+
C.getLazyResolver()->resolveDeclSignature(v);
135+
if (!v->hasInterfaceType())
136+
return false;
137+
auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
138+
return (bool)TypeChecker::conformsToProtocol(varType, mathProto, DC, None);
139+
});
140+
}
141+
142+
// Synthesize body for the given `ElementaryFunction` protocol requirement.
143+
static void deriveBodyElementaryFunction(AbstractFunctionDecl *funcDecl,
144+
ElementaryFunction op) {
145+
auto *parentDC = funcDecl->getParent();
146+
auto *nominal = parentDC->getSelfNominalTypeDecl();
147+
auto &C = nominal->getASTContext();
148+
149+
// Create memberwise initializer: `Nominal.init(...)`.
150+
auto *memberwiseInitDecl = nominal->getEffectiveMemberwiseInitializer();
151+
assert(memberwiseInitDecl && "Memberwise initializer must exist");
152+
auto *initDRE =
153+
new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true);
154+
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
155+
auto *nominalTypeExpr = TypeExpr::createForDecl(SourceLoc(), nominal,
156+
funcDecl, /*Implicit*/ true);
157+
auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, nominalTypeExpr);
158+
159+
// Get operator protocol requirement.
160+
auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
161+
auto *operatorReq = getElementaryFunctionRequirement(C, op);
162+
163+
// Create reference(s) to operator parameters: one for unary functions and two
164+
// for binary functions.
165+
auto params = funcDecl->getParameters();
166+
auto *firstParamDRE =
167+
new (C) DeclRefExpr(params->get(0), DeclNameLoc(), /*Implicit*/ true);
168+
Expr *secondParamDRE = nullptr;
169+
if (params->size() == 2)
170+
secondParamDRE =
171+
new (C) DeclRefExpr(params->get(1), DeclNameLoc(), /*Implicit*/ true);
172+
173+
// Create call expression combining lhs and rhs members using member operator.
174+
auto createMemberOpCallExpr = [&](VarDecl *member) -> Expr * {
175+
auto module = nominal->getModuleContext();
176+
auto memberType =
177+
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
178+
auto confRef = module->lookupConformance(memberType, mathProto);
179+
assert(confRef && "Member does not conform to math protocol");
180+
181+
// Get member type's elementary function, e.g. `Member.cos`.
182+
// Use protocol requirement declaration for the operator by default: this
183+
// will be dynamically dispatched.
184+
ValueDecl *memberOpDecl = operatorReq;
185+
// If conformance reference is concrete, then use concrete witness
186+
// declaration for the operator.
187+
if (confRef->isConcrete())
188+
memberOpDecl = confRef->getConcrete()->getWitnessDecl(
189+
operatorReq, C.getLazyResolver());
190+
assert(memberOpDecl && "Member operator declaration must exist");
191+
auto memberOpDRE =
192+
new (C) DeclRefExpr(memberOpDecl, DeclNameLoc(), /*Implicit*/ true);
193+
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
194+
auto memberOpExpr =
195+
new (C) DotSyntaxCallExpr(memberOpDRE, SourceLoc(), memberTypeExpr);
196+
197+
// - For unary ops, create expression:
198+
// `<op>(x.member)`.
199+
// - For `pow(_ x: Self, _ y: Self)`, create expression:
200+
// `<op>(x.member, y.member)`.
201+
// - For `pow(_ x: Self, _ n: Int)` and `root(_ x: Self, n: Int)`, create:
202+
// `<op>(x.member, n)`.
203+
Expr *firstArg = new (C) MemberRefExpr(firstParamDRE, SourceLoc(), member,
204+
DeclNameLoc(), /*Implicit*/ true);
205+
Expr *secondArg = nullptr;
206+
if (secondParamDRE) {
207+
if (op == PowInt || op == Root)
208+
secondArg = secondParamDRE;
209+
else
210+
secondArg = new (C) MemberRefExpr(secondParamDRE, SourceLoc(), member,
211+
DeclNameLoc(), /*Implicit*/ true);
212+
}
213+
SmallVector<Expr *, 2> memberOpArgs{firstArg};
214+
if (secondArg)
215+
memberOpArgs.push_back(secondArg);
216+
SmallVector<Identifier, 2> memberOpArgLabels(memberOpArgs.size());
217+
auto *memberOpCallExpr = CallExpr::createImplicit(
218+
C, memberOpExpr, memberOpArgs, memberOpArgLabels);
219+
return memberOpCallExpr;
220+
};
221+
222+
// Create array of member operator call expressions.
223+
llvm::SmallVector<Expr *, 2> memberOpCallExprs;
224+
llvm::SmallVector<Identifier, 2> memberNames;
225+
for (auto member : nominal->getStoredProperties()) {
226+
memberOpCallExprs.push_back(createMemberOpCallExpr(member));
227+
memberNames.push_back(member->getName());
228+
}
229+
// Call memberwise initializer with member operator call expressions.
230+
auto *callExpr =
231+
CallExpr::createImplicit(C, initExpr, memberOpCallExprs, memberNames);
232+
ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);
233+
funcDecl->setBody(
234+
BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true));
235+
}
236+
237+
#define ELEMENTARY_FUNCTION(ID, NAME) \
238+
static void deriveBodyElementaryFunctions_##ID(AbstractFunctionDecl *funcDecl, \
239+
void *) { \
240+
deriveBodyElementaryFunction(funcDecl, ID); \
241+
}
242+
#include "DerivedConformanceElementaryFunctions.def"
243+
#undef ELEMENTARY_FUNCTION
244+
245+
// Synthesize function declaration for the given math operator.
246+
static ValueDecl *deriveElementaryFunction(DerivedConformance &derived,
247+
ElementaryFunction op) {
248+
auto nominal = derived.Nominal;
249+
auto parentDC = derived.getConformanceContext();
250+
auto &C = derived.TC.Context;
251+
auto selfInterfaceType = parentDC->getDeclaredInterfaceType();
252+
253+
// Create parameter declaration with the given name and type.
254+
auto createParamDecl = [&](StringRef name, Type type) -> ParamDecl * {
255+
auto *param = new (C)
256+
ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
257+
Identifier(), SourceLoc(), C.getIdentifier(name), parentDC);
258+
param->setInterfaceType(type);
259+
return param;
260+
};
261+
262+
ParameterList *params = nullptr;
263+
264+
switch (op) {
265+
#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \
266+
case ID: \
267+
params = \
268+
ParameterList::create(C, {createParamDecl("x", selfInterfaceType)}); \
269+
break;
270+
#include "DerivedConformanceElementaryFunctions.def"
271+
#undef ELEMENTARY_FUNCTION_UNARY
272+
case Pow:
273+
params =
274+
ParameterList::create(C, {createParamDecl("x", selfInterfaceType),
275+
createParamDecl("y", selfInterfaceType)});
276+
break;
277+
case PowInt:
278+
case Root:
279+
params = ParameterList::create(
280+
C, {createParamDecl("x", selfInterfaceType),
281+
createParamDecl("n", C.getIntDecl()->getDeclaredInterfaceType())});
282+
break;
283+
}
284+
285+
auto operatorId = C.getIdentifier(getElementaryFunctionName(op));
286+
DeclName operatorDeclName(C, operatorId, params);
287+
auto operatorDecl =
288+
FuncDecl::create(C, SourceLoc(), StaticSpellingKind::KeywordStatic,
289+
SourceLoc(), operatorDeclName, SourceLoc(),
290+
/*Throws*/ false, SourceLoc(),
291+
/*GenericParams*/ nullptr, params,
292+
TypeLoc::withoutLoc(selfInterfaceType), parentDC);
293+
operatorDecl->setImplicit();
294+
switch (op) {
295+
#define ELEMENTARY_FUNCTION(ID, NAME) \
296+
case ID: \
297+
operatorDecl->setBodySynthesizer(deriveBodyElementaryFunctions_##ID, \
298+
nullptr); \
299+
break;
300+
#include "DerivedConformanceElementaryFunctions.def"
301+
#undef ELEMENTARY_FUNCTION
302+
}
303+
if (auto env = parentDC->getGenericEnvironmentOfContext())
304+
operatorDecl->setGenericEnvironment(env);
305+
operatorDecl->computeType();
306+
operatorDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
307+
operatorDecl->setValidationToChecked();
308+
309+
derived.addMembersToConformanceContext({operatorDecl});
310+
C.addSynthesizedDecl(operatorDecl);
311+
312+
return operatorDecl;
313+
}
314+
315+
ValueDecl *
316+
DerivedConformance::deriveElementaryFunctions(ValueDecl *requirement) {
317+
// Diagnose conformances in disallowed contexts.
318+
if (checkAndDiagnoseDisallowedContext(requirement))
319+
return nullptr;
320+
// Create memberwise initializer for nominal type if it doesn't already exist.
321+
getOrCreateEffectiveMemberwiseInitializer(TC, Nominal);
322+
#define ELEMENTARY_FUNCTION_UNARY(ID, NAME) \
323+
if (requirement->getBaseName() == TC.Context.getIdentifier(NAME)) \
324+
return deriveElementaryFunction(*this, ID);
325+
#include "DerivedConformanceElementaryFunctions.def"
326+
#undef ELEMENTARY_FUNCTION_UNARY
327+
if (requirement->getBaseName() == TC.Context.getIdentifier("root"))
328+
return deriveElementaryFunction(*this, Root);
329+
if (requirement->getBaseName() == TC.Context.getIdentifier("pow")) {
330+
auto *powFuncDecl = cast<FuncDecl>(requirement);
331+
return powFuncDecl->getParameters()->get(1)->getName().str() == "n"
332+
? deriveElementaryFunction(*this, PowInt)
333+
: deriveElementaryFunction(*this, Pow);
334+
}
335+
TC.diagnose(requirement->getLoc(),
336+
diag::broken_elementary_functions_requirement);
337+
return nullptr;
338+
}

0 commit comments

Comments
 (0)