Skip to content

Commit 95b85a9

Browse files
authored
Derive ElementaryFunctions conformances for structs. (#25500)
`ElementaryFunctions` derived conformances enable elementary math functions to work with product spaces formed from `ElementaryFunctions`-conforming types. Conform `Differentiable` synthesized associated types to `ElementaryFunctions` if possible. Similar to existing logic for `VectorProtocol`. Enables efficient, elegant mathematical optimizers.
1 parent 9040108 commit 95b85a9

17 files changed

+732
-20
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2532,6 +2532,8 @@ ERROR(parameterized_invalid_parameters_struct,none,
25322532
"invalid", (Type))
25332533
ERROR(broken_additive_arithmetic_requirement,none,
25342534
"AdditiveArithmetic protocol is broken: unexpected requirement", ())
2535+
ERROR(broken_elementary_functions_requirement,none,
2536+
"ElementaryFunctions protocol is broken: unexpected requirement", ())
25352537
ERROR(broken_vector_protocol_requirement,none,
25362538
"VectorProtocol protocol is broken: unexpected requirement", ())
25372539
ERROR(broken_differentiable_requirement,none,

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ PROTOCOL(Encodable)
7878
PROTOCOL(Decodable)
7979
// SWIFT_ENABLE_TENSORFLOW
8080
PROTOCOL(AdditiveArithmetic)
81+
PROTOCOL(ElementaryFunctions)
8182
PROTOCOL(KeyPathIterable)
8283
PROTOCOL(TensorArrayProtocol)
8384
PROTOCOL(TensorGroup)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4206,6 +4206,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42064206
case KnownProtocolKind::StringInterpolationProtocol:
42074207
// SWIFT_ENABLE_TENSORFLOW
42084208
case KnownProtocolKind::AdditiveArithmetic:
4209+
case KnownProtocolKind::ElementaryFunctions:
42094210
case KnownProtocolKind::KeyPathIterable:
42104211
case KnownProtocolKind::TensorArrayProtocol:
42114212
case KnownProtocolKind::TensorGroup:

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ add_swift_host_library(swiftSema STATIC
2828
DerivedConformanceError.cpp
2929
# SWIFT_ENABLE_TENSORFLOW
3030
DerivedConformanceAdditiveArithmetic.cpp
31+
DerivedConformanceElementaryFunctions.cpp
3132
DerivedConformanceVectorProtocol.cpp
3233
DerivedConformanceDifferentiable.cpp
3334
DerivedConformanceKeyPathIterable.cpp

lib/Sema/DerivedConformanceAdditiveArithmetic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
183183
memberOpExprs.push_back(createMemberOpExpr(member));
184184
memberNames.push_back(member->getName());
185185
}
186-
// Call memberwise initialier with member operator call expressions.
186+
// Call memberwise initializer with member operator call expressions.
187187
auto *callExpr =
188188
CallExpr::createImplicit(C, initExpr, memberOpExprs, memberNames);
189189
ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true);

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,12 +568,14 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
568568
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
569569
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
570570
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType());
571+
auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
572+
auto mathType = TypeLoc::withoutLoc(mathProto->getDeclaredType());
571573
auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
572574
auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType());
573575
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
574576
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());
575577

576-
SmallVector<TypeLoc, 3> inherited {diffableType};
578+
SmallVector<TypeLoc, 3> inherited{diffableType};
577579

578580
// Cache original members and their associated types for later use.
579581
SmallVector<VarDecl *, 8> diffProperties;
@@ -589,9 +591,16 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
589591
bool canDeriveAdditiveArithmetic =
590592
llvm::all_of(diffProperties, [&](VarDecl *vd) {
591593
return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
592-
addArithProto, parentDC,
593-
None);
594-
});
594+
addArithProto, parentDC, None);
595+
});
596+
597+
// Associated struct can derive `ElementaryFunctions` if the associated types
598+
// of all stored properties conform to `ElementaryFunctions`.
599+
bool canDeriveElementaryFunctions =
600+
llvm::all_of(diffProperties, [&](VarDecl *vd) {
601+
return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
602+
mathProto, parentDC, None);
603+
});
595604

596605
// Associated struct can derive `VectorProtocol` if the associated types of
597606
// all members conform to `VectorProtocol` and share the same
@@ -625,6 +634,10 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
625634
None))
626635
inherited.push_back(kpIterableType);
627636
}
637+
// If all members conform to `ElementaryFunctions`, make the associated struct
638+
// conform to `ElementaryFunctions`.
639+
if (canDeriveElementaryFunctions)
640+
inherited.push_back(mathType);
628641
// If all members also conform to `VectorProtocol` with the same `Scalar`
629642
// type, make the associated struct conform to `VectorProtocol` instead of
630643
// just `AdditiveArithmetic`.

0 commit comments

Comments
 (0)