@@ -93,33 +93,35 @@ static ConstructorDecl *getMemberwiseInitializer(NominalTypeDecl *nominal) {
93
93
}
94
94
95
95
// Return the `Scalar` associated type for a ValueDecl if it conforms to
96
- // `VectorNumeric`.
96
+ // `VectorNumeric` in the given context .
97
97
// If the decl does not conform to `VectorNumeric`, return a null `Type`.
98
- static Type getVectorNumericScalarAssocType (ValueDecl *decl) {
98
+ static Type getVectorNumericScalarAssocType (VarDecl *decl, DeclContext *DC ) {
99
99
auto &C = decl->getASTContext ();
100
100
auto *vectorNumericProto = C.getProtocol (KnownProtocolKind::VectorNumeric);
101
- auto declType =
102
- decl->getDeclContext ()->mapTypeIntoContext (decl->getInterfaceType ());
103
- auto conf = TypeChecker::conformsToProtocol (declType, vectorNumericProto,
104
- decl->getDeclContext (),
101
+ if (!decl->hasType ())
102
+ C.getLazyResolver ()->resolveDeclSignature (decl);
103
+ if (!decl->hasType ())
104
+ return Type ();
105
+ auto declType = decl->getType ()->hasArchetype ()
106
+ ? decl->getType ()
107
+ : DC->mapTypeIntoContext (decl->getType ());
108
+ auto conf = TypeChecker::conformsToProtocol (declType, vectorNumericProto, DC,
105
109
ConformanceCheckFlags::Used);
106
110
if (!conf)
107
111
return Type ();
108
112
Type scalarType = ProtocolConformanceRef::getTypeWitnessByName (
109
- decl-> getInterfaceType () , *conf, C.Id_Scalar , C.getLazyResolver ());
113
+ declType , *conf, C.Id_Scalar , C.getLazyResolver ());
110
114
assert (scalarType && " 'Scalar' associated type not found" );
111
115
return scalarType;
112
116
}
113
117
114
- static Type deriveVectorNumeric_Scalar (NominalTypeDecl *nominal) {
115
- // Must be a struct type.
118
+ static Type deriveVectorNumeric_Scalar (NominalTypeDecl *nominal,
119
+ DeclContext *DC) {
120
+ // Nominal type must be a struct. (Zero stored properties is okay.)
116
121
auto *structDecl = dyn_cast<StructDecl>(nominal);
117
122
auto &C = nominal->getASTContext ();
118
123
if (!structDecl)
119
124
return Type ();
120
- // Struct must have at least one stored property.
121
- if (structDecl->getStoredProperties ().empty ())
122
- return Type ();
123
125
// If all stored properties conform to `VectorNumeric` and have the same
124
126
// `Scalar` associated type, return that `Scalar` associated type.
125
127
// Otherwise, the `Scalar` type cannot be derived.
@@ -129,7 +131,7 @@ static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal) {
129
131
C.getLazyResolver ()->resolveDeclSignature (member);
130
132
if (!member->hasInterfaceType ())
131
133
return Type ();
132
- auto scalarType = getVectorNumericScalarAssocType (member);
134
+ auto scalarType = getVectorNumericScalarAssocType (member, DC );
133
135
// If stored property does not conform to `VectorNumeric`, return null
134
136
// `Type`.
135
137
if (!scalarType)
@@ -146,31 +148,31 @@ static Type deriveVectorNumeric_Scalar(NominalTypeDecl *nominal) {
146
148
return sameScalarType;
147
149
}
148
150
149
- bool DerivedConformance::canDeriveAdditiveArithmetic (NominalTypeDecl *nominal) {
150
- // Must be a struct type.
151
+ bool DerivedConformance::canDeriveAdditiveArithmetic (NominalTypeDecl *nominal,
152
+ DeclContext *DC) {
153
+ // Nominal type must be a struct. (Zero stored properties is okay.)
151
154
auto *structDecl = dyn_cast<StructDecl>(nominal);
152
155
if (!structDecl)
153
156
return false ;
154
- // Struct must have at least one stored property.
155
- if (structDecl->getStoredProperties ().empty ())
156
- return false ;
157
157
// All stored properties must conform to `AdditiveArithmetic`.
158
158
auto &C = nominal->getASTContext ();
159
159
auto *addArithProto = C.getProtocol (KnownProtocolKind::AdditiveArithmetic);
160
160
return llvm::all_of (structDecl->getStoredProperties (), [&](VarDecl *v) {
161
- if (!v->hasInterfaceType ())
161
+ if (!v->getType ())
162
162
C.getLazyResolver ()->resolveDeclSignature (v);
163
- if (!v->hasInterfaceType ())
163
+ if (!v->getType ())
164
164
return false ;
165
- auto conf = TypeChecker::conformsToProtocol (v->getType (), addArithProto,
166
- v->getDeclContext (),
167
- ConformanceCheckFlags::Used);
168
- return (bool )conf;
165
+ auto declType = v->getType ()->hasArchetype ()
166
+ ? v->getType ()
167
+ : DC->mapTypeIntoContext (v->getType ());
168
+ return (bool )TypeChecker::conformsToProtocol (declType, addArithProto, DC,
169
+ ConformanceCheckFlags::Used);
169
170
});
170
171
}
171
172
172
- bool DerivedConformance::canDeriveVectorNumeric (NominalTypeDecl *nominal) {
173
- return bool (deriveVectorNumeric_Scalar (nominal));
173
+ bool DerivedConformance::canDeriveVectorNumeric (NominalTypeDecl *nominal,
174
+ DeclContext *DC) {
175
+ return bool (deriveVectorNumeric_Scalar (nominal, DC));
174
176
}
175
177
176
178
// Synthesize body for the given math operator.
@@ -203,7 +205,11 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
203
205
// Create expression combining lhs and rhs members using member operator.
204
206
auto createMemberOpExpr = [&](VarDecl *member) -> Expr * {
205
207
auto module = nominal->getModuleContext ();
206
- auto confRef = module ->lookupConformance (member->getType (), proto);
208
+ auto memberType = member->getType ()->hasArchetype ()
209
+ ? member->getType ()
210
+ : nominal->mapTypeIntoContext (
211
+ member->getType ()->mapTypeOutOfContext ());
212
+ auto confRef = module ->lookupConformance (memberType, proto);
207
213
assert (confRef && " Member does not conform to math protocol" );
208
214
209
215
// Get member type's math operator, e.g. `Member.+`.
@@ -213,8 +219,10 @@ static void deriveBodyMathOperator(AbstractFunctionDecl *funcDecl,
213
219
// If conformance reference is concrete, then use concrete witness
214
220
// declaration for the operator.
215
221
if (confRef->isConcrete ())
216
- memberOpDecl =
217
- confRef->getConcrete ()->getWitnessDecl (operatorReq, nullptr );
222
+ if (auto opDecl =
223
+ confRef->getConcrete ()->getWitnessDecl (operatorReq, nullptr ))
224
+ memberOpDecl = opDecl;
225
+ assert (memberOpDecl && " Member operator declaration must exist" );
218
226
auto memberOpDRE =
219
227
new (C) DeclRefExpr (memberOpDecl, DeclNameLoc (), /* Implicit*/ true );
220
228
auto *memberTypeExpr = TypeExpr::createImplicit (member->getType (), C);
@@ -296,8 +304,9 @@ static ValueDecl *deriveMathOperator(DerivedConformance &derived,
296
304
case Subtract:
297
305
return std::make_pair (selfInterfaceType, selfInterfaceType);
298
306
case ScalarMultiply:
299
- return std::make_pair (deriveVectorNumeric_Scalar (nominal),
300
- selfInterfaceType);
307
+ return std::make_pair (
308
+ deriveVectorNumeric_Scalar (nominal, parentDC)->mapTypeOutOfContext (),
309
+ selfInterfaceType);
301
310
}
302
311
};
303
312
@@ -376,7 +385,7 @@ static void deriveBodyAdditiveArithmetic_zero(AbstractFunctionDecl *funcDecl) {
376
385
return new (C) MemberRefExpr (memberTypeExpr, SourceLoc (), zeroReq,
377
386
DeclNameLoc (), /* Implicit*/ true );
378
387
}
379
- // Otherwise,return reference to concrete witness declaration for `zero`.
388
+ // Otherwise, return reference to concrete witness declaration for `zero`.
380
389
auto conf = confRef->getConcrete ();
381
390
auto zeroDecl = conf->getWitnessDecl (zeroReq, nullptr );
382
391
return new (C) MemberRefExpr (memberTypeExpr, SourceLoc (), zeroDecl,
@@ -462,8 +471,7 @@ ValueDecl *DerivedConformance::deriveVectorNumeric(ValueDecl *requirement) {
462
471
463
472
Type DerivedConformance::deriveVectorNumeric (AssociatedTypeDecl *requirement) {
464
473
if (requirement->getBaseName () == TC.Context .Id_Scalar ) {
465
- auto rawType = deriveVectorNumeric_Scalar (Nominal);
466
- return getConformanceContext ()->mapTypeIntoContext (rawType);
474
+ return deriveVectorNumeric_Scalar (Nominal, getConformanceContext ());
467
475
}
468
476
TC.diagnose (requirement->getLoc (), diag::broken_vector_numeric_requirement);
469
477
return nullptr ;
0 commit comments