@@ -363,6 +363,89 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
363
363
return tensorHandleCountDecl;
364
364
}
365
365
366
+
367
+ // / Derive the body for the '_typeList' getter.
368
+ static void
369
+ deriveBodyTensorArrayProtocol_typeList (AbstractFunctionDecl *funcDecl) {
370
+ auto *parentDC = funcDecl->getParent ();
371
+ auto *nominal = funcDecl->getDeclContext ()->getSelfNominalTypeDecl ();
372
+ auto &C = nominal->getASTContext ();
373
+
374
+ auto *tensorGroupProto = C.getProtocol (KnownProtocolKind::TensorGroup);
375
+ auto *typeListReq = getProtocolRequirement (tensorGroupProto, C.Id_typeList );
376
+
377
+ // Concatenate all member `_typeList` arrays.
378
+ Type arrayType = BoundGenericType::get (
379
+ C.getArrayDecl (), Type (),
380
+ {C.getTensorDataTypeDecl ()->getDeclaredInterfaceType ()});
381
+ auto *arrayTypeExpr = TypeExpr::createImplicit (arrayType, C);
382
+ auto plusOpLookup = C.getArrayDecl ()->lookupDirect (C.getIdentifier (" +" ));
383
+ assert (plusOpLookup.size () == 1 && " Ambiguous 'Array.+' operator." );
384
+ ValueDecl *plusOpDecl = plusOpLookup.front ();
385
+ auto plusOpDRE = new (C)
386
+ DeclRefExpr (plusOpDecl, DeclNameLoc (), /* Implicit*/ true );
387
+ auto plusOpExpr = new (C)
388
+ DotSyntaxCallExpr (plusOpDRE, SourceLoc (), arrayTypeExpr);
389
+ Expr *typeListExpr = ArrayExpr::create (C, SourceLoc (), {}, {}, SourceLoc ());
390
+ for (auto member : nominal->getStoredProperties ()) {
391
+ auto memberType =
392
+ parentDC->mapTypeIntoContext (member->getValueInterfaceType ());
393
+ auto *memberTypeExpr = TypeExpr::createImplicit (memberType, C);
394
+ auto *memberTypeListExpr = new (C)
395
+ MemberRefExpr (memberTypeExpr, SourceLoc (), typeListReq,
396
+ DeclNameLoc (), /* Implicit*/ true );
397
+ // Create expression `lhsArg + rhsArg`.
398
+ auto *plusOpArgs =
399
+ TupleExpr::create (C, SourceLoc (), {typeListExpr, memberTypeListExpr},
400
+ {}, {}, SourceLoc (), /* HasTrailingClosure*/ false ,
401
+ /* Implicit*/ true );
402
+ typeListExpr = new (C) BinaryExpr (plusOpExpr, plusOpArgs,
403
+ /* Implicit*/ true );
404
+ }
405
+
406
+ // Return the resulting data types array.
407
+ auto *returnStmt = new (C) ReturnStmt (SourceLoc (), typeListExpr);
408
+ auto *body = BraceStmt::create (C, SourceLoc (), {returnStmt}, SourceLoc (),
409
+ /* Implicit*/ true );
410
+ funcDecl->setBody (BraceStmt::create (C, SourceLoc (), {body}, SourceLoc (),
411
+ /* Implicit*/ true ));
412
+ }
413
+
414
+ // / Derive a '_typeList' implementation.
415
+ static ValueDecl *deriveTensorArrayProtocol_typeList (
416
+ DerivedConformance &derived) {
417
+ auto nominal = derived.Nominal ;
418
+ auto &TC = derived.TC ;
419
+ ASTContext &C = TC.Context ;
420
+
421
+ auto parentDC = derived.getConformanceContext ();
422
+ Type dataTypeArrayType = BoundGenericType::get (
423
+ C.getArrayDecl (), Type (),
424
+ {C.getTensorDataTypeDecl ()->getDeclaredInterfaceType ()});
425
+ auto returnType = parentDC->mapTypeIntoContext (dataTypeArrayType);
426
+
427
+ // Create `_typeList` property declaration.
428
+ VarDecl *typeListDecl;
429
+ PatternBindingDecl *patDecl;
430
+ std::tie (typeListDecl, patDecl) = derived.declareDerivedProperty (
431
+ C.Id_typeList , returnType, returnType, /* isStatic*/ false ,
432
+ /* isFinal*/ false );
433
+
434
+ // Add `@inlinable` to the `_typeList` declaration.
435
+ if (nominal->getEffectiveAccess () > AccessLevel::Internal)
436
+ typeListDecl->getAttrs ().add (new (C) InlinableAttr (/* implicit*/ true ));
437
+
438
+ // Create `_typeList` getter.
439
+ auto *getterDecl = derived.declareDerivedPropertyGetter (
440
+ TC, typeListDecl, returnType);
441
+ getterDecl->setBodySynthesizer (deriveBodyTensorArrayProtocol_typeList);
442
+ typeListDecl->setAccessors (StorageImplInfo::getImmutableComputed (),
443
+ SourceLoc (), {getterDecl}, SourceLoc ());
444
+ derived.addMembersToConformanceContext ({getterDecl, typeListDecl, patDecl});
445
+
446
+ return typeListDecl;
447
+ }
448
+
366
449
// Synthesize body for `init(_owning:count:)`.
367
450
static void
368
451
deriveBodyTensorArrayProtocol_init (AbstractFunctionDecl *funcDecl) {
@@ -584,6 +667,8 @@ ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
584
667
return deriveTensorArrayProtocol_unpackTensorHandles (*this );
585
668
if (requirement->getBaseName () == TC.Context .Id_tensorHandleCount )
586
669
return deriveTensorArrayProtocol_tensorHandleCount (*this );
670
+ if (requirement->getBaseName () == TC.Context .Id_typeList )
671
+ return deriveTensorArrayProtocol_typeList (*this );
587
672
if (requirement->getBaseName () == DeclBaseName::createConstructor ())
588
673
return deriveTensorArrayProtocol_init (*this );
589
674
TC.diagnose (requirement->getLoc (),
0 commit comments