Skip to content

Commit 9a7c32a

Browse files
committed
[Sema] Fix Differentiable.TangentVector derived conformances.
Fix logic for deriving `VectorProtocol` for derived `TangentVector` struct.
1 parent b42dd7c commit 9a7c32a

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,12 +542,12 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
542542
Type sameScalarType;
543543
bool canDeriveVectorProtocol = !diffProperties.empty() &&
544544
llvm::all_of(diffProperties, [&](VarDecl *vd) {
545-
auto conf = TC.conformsToProtocol(getTangentVectorType(vd, parentDC),
546-
vectorProto, nominal, None);
545+
auto tanType = getTangentVectorType(vd, parentDC);
546+
auto conf = TC.conformsToProtocol(tanType, vectorProto, nominal, None);
547547
if (!conf)
548548
return false;
549549
auto scalarType =
550-
conf->getTypeWitnessByName(vd->getType(), C.Id_VectorSpaceScalar);
550+
conf->getTypeWitnessByName(tanType, C.Id_VectorSpaceScalar);
551551
if (!sameScalarType) {
552552
sameScalarType = scalarType;
553553
return true;

test/Sema/struct_differentiable.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,16 @@ func testAllMembersVectorProtocol() {
142142
assertConformsToVectorProtocol(AllMembersVectorProtocol.TangentVector.self)
143143
}
144144

145+
// Test generic `VectorProtocol`-conforming members.
146+
protocol Module: Differentiable where TangentVector: VectorProtocol {}
147+
struct Sequential<Layer1: Module, Layer2: Module>: Module
148+
where
149+
Layer1.TangentVector.VectorSpaceScalar == Layer2.TangentVector.VectorSpaceScalar
150+
{
151+
var layer1: Layer1
152+
var layer2: Layer2
153+
}
154+
145155
// Test type `AllMembersElementaryFunctions` whose members conforms to `ElementaryFunctions`,
146156
// in which case we should make `TangentVector` conform to `ElementaryFunctions`.
147157
struct MyVector2 : ElementaryFunctions, Differentiable, EuclideanDifferentiable {

0 commit comments

Comments
 (0)