Skip to content

Commit bd14b05

Browse files
authored
[stdlib] Add PointwiseMultiplicative protocol. (#25772)
`PointwiseMultiplicative` represents a ring over multiplication with identity element "one" and a multiplicative inverse (reciprocal). Conform `Differentiable` synthesized struct types to `PointwiseMultiplicative` protocol if possible. Enables revamping optimizers, which requires division (i.e. multiplication with multiplicative inverse).
1 parent ddf4bb2 commit bd14b05

14 files changed

+410
-85
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2548,26 +2548,22 @@ ERROR(broken_encodable_requirement,none,
25482548
ERROR(broken_decodable_requirement,none,
25492549
"Decodable protocol is broken: unexpected requirement", ())
25502550
// SWIFT_ENABLE_TENSORFLOW
2551-
ERROR(broken_key_path_iterable_requirement,none,
2552-
"KeyPathIterable protocol is broken: unexpected requirement", ())
2553-
ERROR(broken_tensor_array_protocol_requirement,none,
2554-
"TensorArrayProtocol protocol is broken: unexpected requirement", ())
2555-
ERROR(broken_tensor_group_requirement,none,
2556-
"TensorGroup protocol is broken: unexpected requirement", ())
2557-
ERROR(parameterized_no_parameters_struct,none,
2558-
"cannot automatically synthesize %0 because 'Parameters' struct does not "
2559-
"exist", (Type))
2560-
ERROR(parameterized_invalid_parameters_struct,none,
2561-
"cannot automatically synthesize %0 because 'Parameters' struct is "
2562-
"invalid", (Type))
25632551
ERROR(broken_additive_arithmetic_requirement,none,
25642552
"AdditiveArithmetic protocol is broken: unexpected requirement", ())
2553+
ERROR(broken_pointwise_multiplicative_requirement,none,
2554+
"PointwiseMultiplicative protocol is broken: unexpected requirement", ())
25652555
ERROR(broken_elementary_functions_requirement,none,
25662556
"ElementaryFunctions protocol is broken: unexpected requirement", ())
25672557
ERROR(broken_vector_protocol_requirement,none,
25682558
"VectorProtocol protocol is broken: unexpected requirement", ())
25692559
ERROR(broken_differentiable_requirement,none,
25702560
"Differentiable protocol is broken: unexpected requirement", ())
2561+
ERROR(broken_key_path_iterable_requirement,none,
2562+
"KeyPathIterable protocol is broken: unexpected requirement", ())
2563+
ERROR(broken_tensor_array_protocol_requirement,none,
2564+
"TensorArrayProtocol protocol is broken: unexpected requirement", ())
2565+
ERROR(broken_tensor_group_requirement,none,
2566+
"TensorGroup protocol is broken: unexpected requirement", ())
25712567
WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none,
25722568
"stored property %0 has no derivative because it does not conform to "
25732569
"'Differentiable'; add an explicit '@noDerivative' attribute"

include/swift/AST/KnownIdentifiers.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,10 @@ IDENTIFIER_(unpackTensorHandles)
139139
IDENTIFIER_(tensorHandleCount)
140140
// TensorGroup
141141
IDENTIFIER_(typeList)
142-
// AdditiveArithmetic, VectorProtocol
142+
// AdditiveArithmetic, PointwiseMultiplicative, VectorProtocol
143143
IDENTIFIER(zero)
144+
IDENTIFIER(one)
145+
IDENTIFIER(reciprocal)
144146
IDENTIFIER(VectorSpaceScalar)
145147
IDENTIFIER(adding)
146148
IDENTIFIER(subtracting)

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ PROTOCOL(Encodable)
7878
PROTOCOL(Decodable)
7979
// SWIFT_ENABLE_TENSORFLOW
8080
PROTOCOL(AdditiveArithmetic)
81+
PROTOCOL(PointwiseMultiplicative)
8182
PROTOCOL(ElementaryFunctions)
8283
PROTOCOL(KeyPathIterable)
8384
PROTOCOL(TensorArrayProtocol)

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4206,6 +4206,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42064206
case KnownProtocolKind::StringInterpolationProtocol:
42074207
// SWIFT_ENABLE_TENSORFLOW
42084208
case KnownProtocolKind::AdditiveArithmetic:
4209+
case KnownProtocolKind::PointwiseMultiplicative:
42094210
case KnownProtocolKind::ElementaryFunctions:
42104211
case KnownProtocolKind::KeyPathIterable:
42114212
case KnownProtocolKind::TensorArrayProtocol:

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ add_swift_host_library(swiftSema STATIC
2828
DerivedConformanceEquatableHashable.cpp
2929
DerivedConformanceError.cpp
3030
# SWIFT_ENABLE_TENSORFLOW
31-
DerivedConformanceAdditiveArithmetic.cpp
31+
DerivedConformanceRingMathProtocols.cpp
3232
DerivedConformanceElementaryFunctions.cpp
3333
DerivedConformanceVectorProtocol.cpp
3434
DerivedConformanceDifferentiable.cpp

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,14 +568,17 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
568568
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
569569
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
570570
auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType());
571+
auto *pointMulProto =
572+
C.getProtocol(KnownProtocolKind::PointwiseMultiplicative);
573+
auto pointMulType = TypeLoc::withoutLoc(pointMulProto->getDeclaredType());
571574
auto *mathProto = C.getProtocol(KnownProtocolKind::ElementaryFunctions);
572575
auto mathType = TypeLoc::withoutLoc(mathProto->getDeclaredType());
573576
auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol);
574577
auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType());
575578
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
576579
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());
577580

578-
SmallVector<TypeLoc, 3> inherited{diffableType};
581+
SmallVector<TypeLoc, 4> inherited{diffableType};
579582

580583
// Cache original members and their associated types for later use.
581584
SmallVector<VarDecl *, 8> diffProperties;
@@ -594,6 +597,14 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
594597
addArithProto, parentDC, None);
595598
});
596599

600+
// Associated struct can derive `PointwiseMultiplicative` if the associated
601+
// types of all stored properties conform to `PointwiseMultiplicative`.
602+
bool canDerivePointwiseMultiplicative =
603+
llvm::all_of(diffProperties, [&](VarDecl *vd) {
604+
return TC.conformsToProtocol(getAssociatedType(vd, parentDC, id),
605+
pointMulProto, parentDC, None);
606+
});
607+
597608
// Associated struct can derive `ElementaryFunctions` if the associated types
598609
// of all stored properties conform to `ElementaryFunctions`.
599610
bool canDeriveElementaryFunctions =
@@ -634,6 +645,10 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived,
634645
None))
635646
inherited.push_back(kpIterableType);
636647
}
648+
// If all members conform to `PointwiseMultiplicative`, make the associated
649+
// struct conform to `PointwiseMultiplicative`.
650+
if (canDerivePointwiseMultiplicative)
651+
inherited.push_back(pointMulType);
637652
// If all members conform to `ElementaryFunctions`, make the associated struct
638653
// conform to `ElementaryFunctions`.
639654
if (canDeriveElementaryFunctions)

0 commit comments

Comments
 (0)