@@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal,
37
37
auto *structDecl = dyn_cast<StructDecl>(nominal);
38
38
if (!structDecl)
39
39
return false ;
40
- // All stored properties must conform to `TensorArrayProtocol `.
40
+ // All stored properties must conform to `TensorGroup `.
41
41
auto &C = nominal->getASTContext ();
42
- auto *tensorArrayProto =
43
- C.getProtocol (KnownProtocolKind::TensorArrayProtocol );
42
+ auto *tensorGroupProto =
43
+ C.getProtocol (KnownProtocolKind::TensorGroup );
44
44
return llvm::all_of (structDecl->getStoredProperties (), [&](VarDecl *v) {
45
45
if (!v->hasInterfaceType ())
46
46
C.getLazyResolver ()->resolveDeclSignature (v);
47
47
if (!v->hasInterfaceType ())
48
48
return false ;
49
49
auto varType = DC->mapTypeIntoContext (v->getValueInterfaceType ());
50
- return (bool )TypeChecker::conformsToProtocol (varType, tensorArrayProto , DC,
50
+ return (bool )TypeChecker::conformsToProtocol (varType, tensorGroupProto , DC,
51
51
ConformanceCheckFlags::Used);
52
52
});
53
53
}
@@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
66
66
return lookup.front ();
67
67
}
68
68
69
+ // Return the protocol requirement with the specified name.
70
+ static ValueDecl *getProtocolRequirement (ProtocolDecl *proto, DeclName name) {
71
+ auto lookup = proto->lookupDirect (name);
72
+ lookup.erase (std::remove_if (lookup.begin (), lookup.end (),
73
+ [](ValueDecl *v) {
74
+ return !isa<ProtocolDecl>(
75
+ v->getDeclContext ()) ||
76
+ !v->isProtocolRequirement ();
77
+ }),
78
+ lookup.end ());
79
+ assert (lookup.size () == 1 && " Ambiguous protocol requirement" );
80
+ return lookup.front ();
81
+ }
82
+
69
83
// Synthesize body for `_unpackTensorHandles(into:)`.
70
84
static void
71
85
deriveBodyTensorArrayProtocol_unpackTensorHandles (
@@ -349,12 +363,314 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
349
363
return tensorHandleCountDecl;
350
364
}
351
365
366
+
367
+ // / Derive the body for the '_typeList' getter.
368
+ static void
369
+ deriveBodyTensorArrayProtocol_typeList (AbstractFunctionDecl *funcDecl) {
370
+ auto *parentDC = funcDecl->getParent ();
371
+ auto *nominal = funcDecl->getDeclContext ()->getSelfNominalTypeDecl ();
372
+ auto &C = nominal->getASTContext ();
373
+
374
+ auto *tensorGroupProto = C.getProtocol (KnownProtocolKind::TensorGroup);
375
+ auto *typeListReq = getProtocolRequirement (tensorGroupProto, C.Id_typeList );
376
+
377
+ // Concatenate all member `_typeList` arrays.
378
+ Type arrayType = BoundGenericType::get (
379
+ C.getArrayDecl (), Type (),
380
+ {C.getTensorDataTypeDecl ()->getDeclaredInterfaceType ()});
381
+ auto *arrayTypeExpr = TypeExpr::createImplicit (arrayType, C);
382
+ auto plusOpLookup = C.getArrayDecl ()->lookupDirect (C.getIdentifier (" +" ));
383
+ assert (plusOpLookup.size () == 1 && " Ambiguous 'Array.+' operator." );
384
+ ValueDecl *plusOpDecl = plusOpLookup.front ();
385
+ auto plusOpDRE = new (C)
386
+ DeclRefExpr (plusOpDecl, DeclNameLoc (), /* Implicit*/ true );
387
+ auto plusOpExpr = new (C)
388
+ DotSyntaxCallExpr (plusOpDRE, SourceLoc (), arrayTypeExpr);
389
+ Expr *typeListExpr = ArrayExpr::create (C, SourceLoc (), {}, {}, SourceLoc ());
390
+ for (auto member : nominal->getStoredProperties ()) {
391
+ auto memberType =
392
+ parentDC->mapTypeIntoContext (member->getValueInterfaceType ());
393
+ auto *memberTypeExpr = TypeExpr::createImplicit (memberType, C);
394
+ auto *memberTypeListExpr = new (C)
395
+ MemberRefExpr (memberTypeExpr, SourceLoc (), typeListReq,
396
+ DeclNameLoc (), /* Implicit*/ true );
397
+ // Create expression `lhsArg + rhsArg`.
398
+ auto *plusOpArgs =
399
+ TupleExpr::create (C, SourceLoc (), {typeListExpr, memberTypeListExpr},
400
+ {}, {}, SourceLoc (), /* HasTrailingClosure*/ false ,
401
+ /* Implicit*/ true );
402
+ typeListExpr = new (C) BinaryExpr (plusOpExpr, plusOpArgs,
403
+ /* Implicit*/ true );
404
+ }
405
+
406
+ // Return the resulting data types array.
407
+ auto *returnStmt = new (C) ReturnStmt (SourceLoc (), typeListExpr);
408
+ auto *body = BraceStmt::create (C, SourceLoc (), {returnStmt}, SourceLoc (),
409
+ /* Implicit*/ true );
410
+ funcDecl->setBody (BraceStmt::create (C, SourceLoc (), {body}, SourceLoc (),
411
+ /* Implicit*/ true ));
412
+ }
413
+
414
+ // / Derive a '_typeList' implementation.
415
+ static ValueDecl *deriveTensorArrayProtocol_typeList (
416
+ DerivedConformance &derived) {
417
+ auto nominal = derived.Nominal ;
418
+ auto &TC = derived.TC ;
419
+ ASTContext &C = TC.Context ;
420
+
421
+ auto parentDC = derived.getConformanceContext ();
422
+ Type dataTypeArrayType = BoundGenericType::get (
423
+ C.getArrayDecl (), Type (),
424
+ {C.getTensorDataTypeDecl ()->getDeclaredInterfaceType ()});
425
+ auto returnType = parentDC->mapTypeIntoContext (dataTypeArrayType);
426
+
427
+ // Create `_typeList` property declaration.
428
+ VarDecl *typeListDecl;
429
+ PatternBindingDecl *patDecl;
430
+ std::tie (typeListDecl, patDecl) = derived.declareDerivedProperty (
431
+ C.Id_typeList , returnType, returnType, /* isStatic*/ false ,
432
+ /* isFinal*/ false );
433
+
434
+ // Add `@inlinable` to the `_typeList` declaration.
435
+ if (nominal->getEffectiveAccess () > AccessLevel::Internal)
436
+ typeListDecl->getAttrs ().add (new (C) InlinableAttr (/* implicit*/ true ));
437
+
438
+ // Create `_typeList` getter.
439
+ auto *getterDecl = derived.declareDerivedPropertyGetter (
440
+ TC, typeListDecl, returnType);
441
+ getterDecl->setBodySynthesizer (deriveBodyTensorArrayProtocol_typeList);
442
+ typeListDecl->setAccessors (StorageImplInfo::getImmutableComputed (),
443
+ SourceLoc (), {getterDecl}, SourceLoc ());
444
+ derived.addMembersToConformanceContext ({getterDecl, typeListDecl, patDecl});
445
+
446
+ return typeListDecl;
447
+ }
448
+
449
+ // Synthesize body for `init(_owning:count:)`.
450
+ static void
451
+ deriveBodyTensorArrayProtocol_init (AbstractFunctionDecl *funcDecl) {
452
+ auto *parentDC = funcDecl->getParent ();
453
+ auto *nominal = parentDC->getSelfNominalTypeDecl ();
454
+ auto &C = nominal->getASTContext ();
455
+
456
+ // Obtain the address type.
457
+ auto cTensorHandleType = C.getOpaquePointerDecl ()->getDeclaredType ();
458
+ auto baseAddressType = BoundGenericType::get (
459
+ C.getUnsafePointerDecl (), Type (), {cTensorHandleType});
460
+ auto addressType = BoundGenericType::get (
461
+ C.getOptionalDecl (), Type (), {baseAddressType});
462
+ auto *addressTE = TypeExpr::createImplicit (addressType, C);
463
+
464
+ // Get references to `self` and parameter declarations.
465
+ auto *selfDecl = funcDecl->getImplicitSelfDecl ();
466
+ auto *selfDRE = new (C)
467
+ DeclRefExpr (selfDecl, DeclNameLoc (), /* Implicit*/ true );
468
+ auto *paramDecl = funcDecl->getParameters ()->get (0 );
469
+ auto *paramDRE = new (C)
470
+ DeclRefExpr (paramDecl, DeclNameLoc (), /* Implicit*/ true );
471
+
472
+ // Create an `if var` statement for the current address.
473
+ VarDecl *currAddressDecl = new (C) VarDecl (
474
+ /* IsStatic*/ false , VarDecl::Specifier::Var, /* IsCaptureList*/ false ,
475
+ SourceLoc (), C.getIdentifier (" currentAddress" ), funcDecl);
476
+ currAddressDecl->setImplicit ();
477
+ currAddressDecl->setHasNonPatternBindingInit (true );
478
+ currAddressDecl->setInterfaceType (baseAddressType);
479
+ currAddressDecl->setValidationToChecked ();
480
+
481
+ Pattern *currAddressPat = new (C)
482
+ NamedPattern (currAddressDecl, /* implicit*/ true );
483
+ currAddressPat = new (C)
484
+ VarPattern (SourceLoc (), /* isLet*/ false , currAddressPat,
485
+ /* implicit*/ true );
486
+ currAddressPat = new (C)
487
+ OptionalSomePattern (currAddressPat, currAddressPat->getEndLoc (),
488
+ /* implicit*/ true );
489
+ StmtConditionElement cond[] = {
490
+ StmtConditionElement (SourceLoc (), currAddressPat, /* Init*/ paramDRE)};
491
+
492
+ // Get the necessary protocol requirements.
493
+ auto *tensorGroupProto = C.getProtocol (KnownProtocolKind::TensorGroup);
494
+ auto *tensorArrayProto = C.getProtocol (
495
+ KnownProtocolKind::TensorArrayProtocol);
496
+ auto initName = DeclName (
497
+ C, DeclBaseName::createConstructor (), {C.getIdentifier (" _owning" )});
498
+ auto *initReq = getProtocolRequirement (tensorGroupProto, initName);
499
+ auto *tensorHandleCountReq = getProtocolRequirement (
500
+ tensorArrayProto, C.Id_tensorHandleCount );
501
+
502
+ Type intType = C.getIntDecl ()->getDeclaredType ();
503
+ TypeExpr *intTE = TypeExpr::createImplicit (intType, C);
504
+
505
+ // Iterate over members and call `self.t = T(_owning:)`.
506
+ llvm::SmallVector<ASTNode, 2 > thenMemberExprs;
507
+ llvm::SmallVector<ASTNode, 2 > elseMemberExprs;
508
+ for (auto member : nominal->getStoredProperties ()) {
509
+ auto memberType = parentDC->mapTypeIntoContext (
510
+ member->getValueInterfaceType ());
511
+ auto *memberTypeExpr = TypeExpr::createImplicit (memberType, C);
512
+ auto module = nominal->getModuleContext ();
513
+ auto confRef = module ->lookupConformance (
514
+ memberType, tensorGroupProto);
515
+ assert (confRef && " Member does not conform to `TensorGroup`" );
516
+
517
+ // Get member type's constructor, e.g. `MemberType.init(_owning:)`.
518
+ // Use protocol requirement declaration for the method by default: this
519
+ // will be dynamically dispatched.
520
+ ValueDecl *memberInitDecl = initReq;
521
+ // If conformance reference is concrete, then use concrete witness
522
+ // declaration for the constructor.
523
+ if (confRef->isConcrete ())
524
+ memberInitDecl = confRef->getConcrete ()->getWitnessDecl (
525
+ initReq, C.getLazyResolver ());
526
+ assert (memberInitDecl && " Member constructor declaration must exist" );
527
+ auto memberInitDRE = new (C) DeclRefExpr (
528
+ memberInitDecl, DeclNameLoc (), /* implicit*/ true );
529
+ memberInitDRE->setFunctionRefKind (FunctionRefKind::SingleApply);
530
+
531
+ // Create reference to member constructor: `MemberType.init(_owning:)`.
532
+ auto *memberInitExpr = new (C) ConstructorRefCallExpr (
533
+ memberInitDRE, memberTypeExpr);
534
+
535
+ auto *addressDRE = new (C) DeclRefExpr (
536
+ currAddressDecl, DeclNameLoc (), /* implicit*/ true );
537
+ auto *loadExpr = new (C) LoadExpr (addressDRE, baseAddressType);
538
+
539
+ // Initialize the member using its TensorGroup constructor.
540
+ // Note that, initialization is dependent on the branch of the
541
+ // if-statement taken.
542
+ auto *thenInitExpr = new (C) InjectIntoOptionalExpr (loadExpr, addressType);
543
+ auto *thenInitCallExpr = CallExpr::createImplicit (
544
+ C, memberInitExpr, {thenInitExpr}, {C.getIdentifier (" _owning" )});
545
+
546
+ // Create a nil expression with type UnsafePointer<CTensorHandle>? for the
547
+ // `else` branch.
548
+ auto *nilDecl = C.getOptionalNoneDecl ();
549
+ auto *nilDRE = new (C) DeclRefExpr (
550
+ nilDecl, DeclNameLoc (), /* implicit*/ true );
551
+ auto *elseInitExpr = new (C) DotSyntaxCallExpr (
552
+ nilDRE, SourceLoc (), addressTE);
553
+ auto *elseInitCallExpr = CallExpr::createImplicit (
554
+ C, memberInitExpr, {elseInitExpr}, {C.getIdentifier (" _owning" )});
555
+
556
+ // Assign the current member to the result of the initializer call.
557
+ auto *memberDRE = new (C) MemberRefExpr (
558
+ selfDRE, SourceLoc (), member, DeclNameLoc (), /* Implicit*/ true );
559
+
560
+ auto *thenAssignMemberExpr = new (C) AssignExpr (
561
+ memberDRE, SourceLoc (), thenInitCallExpr, /* Implicit*/ true );
562
+ auto *elseAssignMemberExpr = new (C) AssignExpr (
563
+ memberDRE, SourceLoc (), elseInitCallExpr, /* Implicit*/ true );
564
+
565
+ thenMemberExprs.push_back (thenAssignMemberExpr);
566
+ elseMemberExprs.push_back (elseAssignMemberExpr);
567
+
568
+ // Advance the current address.
569
+ DeclName advancedName (C, C.getIdentifier (" advanced" ),
570
+ {C.getIdentifier (" by" )});
571
+ auto *advancedMethodExpr =
572
+ new (C) UnresolvedDotExpr (addressDRE, SourceLoc (),
573
+ advancedName, DeclNameLoc (),
574
+ /* Implicit*/ true );
575
+
576
+ // Obtain `MemberType._tensorHandleCount`.
577
+ auto *memberCountMRE = new (C) MemberRefExpr (
578
+ memberDRE, SourceLoc (), tensorHandleCountReq, DeclNameLoc (),
579
+ /* Implicit*/ true );
580
+
581
+ // Cast the tensor handle count to Int.
582
+ auto intInitName = DeclName (C, DeclBaseName::createConstructor (),
583
+ {Identifier ()});
584
+ auto *intInitExpr =
585
+ new (C) UnresolvedDotExpr (intTE, SourceLoc (), intInitName,
586
+ DeclNameLoc (), /* Implicit*/ true );
587
+ auto *intInitCallExpr = CallExpr::createImplicit (
588
+ C, intInitExpr, {memberCountMRE}, {Identifier ()});
589
+
590
+ // Assign the new address.
591
+ auto *assignAddrCallExpr = CallExpr::createImplicit (
592
+ C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier (" by" )});
593
+ auto *assignAddrExpr = new (C) AssignExpr (addressDRE, SourceLoc (),
594
+ assignAddrCallExpr,
595
+ /* Implicit*/ true );
596
+
597
+ thenMemberExprs.push_back (assignAddrExpr);
598
+ }
599
+
600
+ auto *thenBody = BraceStmt::create (
601
+ C, SourceLoc (), C.AllocateCopy (thenMemberExprs), SourceLoc (),
602
+ /* implicit*/ true );
603
+
604
+ auto *elseBody = BraceStmt::create (
605
+ C, SourceLoc (), C.AllocateCopy (elseMemberExprs), SourceLoc (),
606
+ /* implicit*/ true );
607
+
608
+ auto *ifStmt = new (C)
609
+ IfStmt (LabeledStmtInfo (), /* IfLoc*/ SourceLoc (),
610
+ /* Cond*/ C.AllocateCopy (cond), /* Then*/ thenBody,
611
+ /* ElseLoc*/ SourceLoc (), /* Else*/ elseBody, /* implicit*/ true );
612
+
613
+ funcDecl->setBody (BraceStmt::create (C, SourceLoc (), {ifStmt}, SourceLoc (),
614
+ /* implicit*/ true ));
615
+ }
616
+
617
+ // Synthesize the `init(_owning:count:)` function declaration.
618
+ static ValueDecl
619
+ *deriveTensorArrayProtocol_init (DerivedConformance &derived) {
620
+ auto &C = derived.TC .Context ;
621
+ auto nominal = derived.Nominal ;
622
+ auto parentDC = derived.getConformanceContext ();
623
+
624
+ // Obtain the address type.
625
+ auto cTensorHandleType = C.getOpaquePointerDecl ()->getDeclaredType ();
626
+ Type baseAddressType = BoundGenericType::get (
627
+ C.getUnsafePointerDecl (), Type (), {cTensorHandleType});
628
+ Type addressType = BoundGenericType::get (
629
+ C.getOptionalDecl (), Type (), {baseAddressType});
630
+ Type intType = C.getIntDecl ()->getDeclaredType ();
631
+
632
+ auto *param1 = new (C) ParamDecl (
633
+ VarDecl::Specifier::Default, SourceLoc (), SourceLoc (),
634
+ C.getIdentifier (" _owning" ), SourceLoc (), C.getIdentifier (" tensorHandles" ),
635
+ parentDC);
636
+ param1->setInterfaceType (addressType);
637
+ auto *param2 = new (C) ParamDecl (
638
+ VarDecl::Specifier::Default, SourceLoc (), SourceLoc (),
639
+ C.getIdentifier (" count" ), SourceLoc (), C.getIdentifier (" count" ), parentDC);
640
+ param2->setInterfaceType (intType);
641
+ ParameterList *params = ParameterList::create (C, {param1, param2});
642
+
643
+ DeclName name (C, DeclBaseName::createConstructor (), params);
644
+ auto *initDecl =
645
+ new (C) ConstructorDecl (name, SourceLoc (), OTK_None, SourceLoc (),
646
+ /* Throws*/ false , SourceLoc (), params,
647
+ /* GenericParams*/ nullptr , parentDC);
648
+ initDecl->setImplicit ();
649
+ initDecl->setSynthesized ();
650
+ initDecl->setBodySynthesizer (deriveBodyTensorArrayProtocol_init);
651
+
652
+ if (auto env = parentDC->getGenericEnvironmentOfContext ())
653
+ initDecl->setGenericEnvironment (env);
654
+ initDecl->computeType (AnyFunctionType::ExtInfo ().withThrows (false ));
655
+ initDecl->copyFormalAccessFrom (nominal, /* sourceIsParentContext*/ true );
656
+ initDecl->setValidationToChecked ();
657
+
658
+ derived.addMembersToConformanceContext ({initDecl});
659
+ C.addSynthesizedDecl (initDecl);
660
+
661
+ return initDecl;
662
+ }
663
+
352
664
ValueDecl *DerivedConformance::deriveTensorArrayProtocol (
353
665
ValueDecl *requirement) {
354
666
if (requirement->getBaseName () == TC.Context .Id_unpackTensorHandles )
355
667
return deriveTensorArrayProtocol_unpackTensorHandles (*this );
356
668
if (requirement->getBaseName () == TC.Context .Id_tensorHandleCount )
357
669
return deriveTensorArrayProtocol_tensorHandleCount (*this );
670
+ if (requirement->getBaseName () == TC.Context .Id_typeList )
671
+ return deriveTensorArrayProtocol_typeList (*this );
672
+ if (requirement->getBaseName () == DeclBaseName::createConstructor ())
673
+ return deriveTensorArrayProtocol_init (*this );
358
674
TC.diagnose (requirement->getLoc (),
359
675
diag::broken_tensor_array_protocol_requirement);
360
676
return nullptr ;
0 commit comments