@@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal,
37
37
auto *structDecl = dyn_cast<StructDecl>(nominal);
38
38
if (!structDecl)
39
39
return false ;
40
- // All stored properties must conform to `TensorArrayProtocol `.
40
+ // All stored properties must conform to `TensorGroup `.
41
41
auto &C = nominal->getASTContext ();
42
- auto *tensorArrayProto =
43
- C.getProtocol (KnownProtocolKind::TensorArrayProtocol );
42
+ auto *tensorGroupProto =
43
+ C.getProtocol (KnownProtocolKind::TensorGroup );
44
44
return llvm::all_of (structDecl->getStoredProperties (), [&](VarDecl *v) {
45
45
if (!v->hasInterfaceType ())
46
46
C.getLazyResolver ()->resolveDeclSignature (v);
47
47
if (!v->hasInterfaceType ())
48
48
return false ;
49
49
auto varType = DC->mapTypeIntoContext (v->getValueInterfaceType ());
50
- return (bool )TypeChecker::conformsToProtocol (varType, tensorArrayProto , DC,
50
+ return (bool )TypeChecker::conformsToProtocol (varType, tensorGroupProto , DC,
51
51
ConformanceCheckFlags::Used);
52
52
});
53
53
}
@@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
66
66
return lookup.front ();
67
67
}
68
68
69
+ // Return the protocol requirement with the specified name.
70
+ static ValueDecl *getProtocolRequirement (ProtocolDecl *proto, DeclName name) {
71
+ auto lookup = proto->lookupDirect (name);
72
+ lookup.erase (std::remove_if (lookup.begin (), lookup.end (),
73
+ [](ValueDecl *v) {
74
+ return !isa<ProtocolDecl>(
75
+ v->getDeclContext ()) ||
76
+ !v->isProtocolRequirement ();
77
+ }),
78
+ lookup.end ());
79
+ assert (lookup.size () == 1 && " Ambiguous protocol requirement" );
80
+ return lookup.front ();
81
+ }
82
+
69
83
// Synthesize body for `_unpackTensorHandles(into:)`.
70
84
static void
71
85
deriveBodyTensorArrayProtocol_unpackTensorHandles (
@@ -349,12 +363,246 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
349
363
return tensorHandleCountDecl;
350
364
}
351
365
366
+ // Synthesize body for `init(_owning:count:)`.
367
+ static void
368
+ deriveBodyTensorArrayProtocol_init (AbstractFunctionDecl *funcDecl) {
369
+ auto *parentDC = funcDecl->getParent ();
370
+ auto *nominal = parentDC->getSelfNominalTypeDecl ();
371
+ auto &C = nominal->getASTContext ();
372
+
373
+ // Obtain the address type.
374
+ auto cTensorHandleType = C.getOpaquePointerDecl ()->getDeclaredType ();
375
+ auto baseAddressType = BoundGenericType::get (
376
+ C.getUnsafePointerDecl (), Type (), {cTensorHandleType});
377
+ auto addressType = BoundGenericType::get (
378
+ C.getOptionalDecl (), Type (), {baseAddressType});
379
+ auto *addressTE = TypeExpr::createImplicit (addressType, C);
380
+
381
+ // Get references to `self` and parameter declarations.
382
+ auto *selfDecl = funcDecl->getImplicitSelfDecl ();
383
+ auto *selfDRE = new (C)
384
+ DeclRefExpr (selfDecl, DeclNameLoc (), /* Implicit*/ true );
385
+ auto *paramDecl = funcDecl->getParameters ()->get (0 );
386
+ auto *paramDRE = new (C)
387
+ DeclRefExpr (paramDecl, DeclNameLoc (), /* Implicit*/ true );
388
+
389
+ // Create an `if var` statement for the current address.
390
+ VarDecl *currAddressDecl = new (C) VarDecl (
391
+ /* IsStatic*/ false , VarDecl::Specifier::Var, /* IsCaptureList*/ false ,
392
+ SourceLoc (), C.getIdentifier (" currentAddress" ), funcDecl);
393
+ currAddressDecl->setImplicit ();
394
+ currAddressDecl->setHasNonPatternBindingInit (true );
395
+ currAddressDecl->setInterfaceType (baseAddressType);
396
+ currAddressDecl->setValidationToChecked ();
397
+
398
+ Pattern *currAddressPat = new (C)
399
+ NamedPattern (currAddressDecl, /* implicit*/ true );
400
+ currAddressPat = new (C)
401
+ VarPattern (SourceLoc (), /* isLet*/ false , currAddressPat,
402
+ /* implicit*/ true );
403
+ currAddressPat = new (C)
404
+ OptionalSomePattern (currAddressPat, currAddressPat->getEndLoc (),
405
+ /* implicit*/ true );
406
+ StmtConditionElement cond[] = {
407
+ StmtConditionElement (SourceLoc (), currAddressPat, /* Init*/ paramDRE)};
408
+
409
+ // Get the necessary protocol requirements.
410
+ auto *tensorGroupProto = C.getProtocol (KnownProtocolKind::TensorGroup);
411
+ auto *tensorArrayProto = C.getProtocol (
412
+ KnownProtocolKind::TensorArrayProtocol);
413
+ auto initName = DeclName (
414
+ C, DeclBaseName::createConstructor (),
415
+ {C.getIdentifier (" _owning" ), C.getIdentifier (" count" )});
416
+ auto *initReq = getProtocolRequirement (tensorArrayProto, initName);
417
+ auto *tensorHandleCountReq = getProtocolRequirement (
418
+ tensorArrayProto, C.Id_tensorHandleCount );
419
+
420
+ Type intType = C.getIntDecl ()->getDeclaredType ();
421
+ TypeExpr *intTE = TypeExpr::createImplicit (intType, C);
422
+
423
+ // Goes through the member TensorGroups and call
424
+ // `self.t = T(_owning:count:)`.
425
+ llvm::SmallVector<ASTNode, 2 > thenMemberExprs;
426
+ llvm::SmallVector<ASTNode, 2 > elseMemberExprs;
427
+ for (auto member : nominal->getStoredProperties ()) {
428
+ auto memberType = parentDC->mapTypeIntoContext (
429
+ member->getValueInterfaceType ());
430
+ auto *memberTypeExpr = TypeExpr::createImplicit (memberType, C);
431
+ auto module = nominal->getModuleContext ();
432
+ auto confRef = module ->lookupConformance (
433
+ memberType, tensorGroupProto);
434
+ assert (confRef && " Member does not conform to `TensorGroup`" );
435
+
436
+ // Get member type's constructor, e.g. `MemberType.init(_owning:)`.
437
+ // Use protocol requirement declaration for the method by default: this
438
+ // will be dynamically dispatched.
439
+ ValueDecl *memberInitDecl = initReq;
440
+ // If conformance reference is concrete, then use concrete witness
441
+ // declaration for the constructor.
442
+ if (confRef->isConcrete ())
443
+ memberInitDecl = confRef->getConcrete ()->getWitnessDecl (
444
+ initReq, C.getLazyResolver ());
445
+ assert (memberInitDecl && " Member constructor declaration must exist" );
446
+ auto memberInitDRE = new (C) DeclRefExpr (
447
+ memberInitDecl, DeclNameLoc (), /* implicit*/ true );
448
+ memberInitDRE->setFunctionRefKind (FunctionRefKind::SingleApply);
449
+
450
+ // Create reference to member constructor: `MemberType.init(_owning:)`.
451
+ auto *memberInitExpr = new (C) ConstructorRefCallExpr (
452
+ memberInitDRE, memberTypeExpr);
453
+
454
+ auto *addressDRE = new (C) DeclRefExpr (
455
+ currAddressDecl, DeclNameLoc (), /* implicit*/ true );
456
+ auto *loadExpr = new (C) LoadExpr (addressDRE, baseAddressType);
457
+
458
+ // Initialize the member using its TensorGroup constructor.
459
+ // Note that, initialization is dependent on the branch of the
460
+ // if-statement taken.
461
+ auto *thenInitExpr = new (C) InjectIntoOptionalExpr (loadExpr, addressType);
462
+ auto *thenInitCallExpr = CallExpr::createImplicit (
463
+ C, memberInitExpr, {thenInitExpr}, {C.getIdentifier (" _owning" )});
464
+
465
+ // Create a nil expression with type UnsafePointer<CTensorHandle>? for the
466
+ // `else` branch.
467
+ auto *nilDecl = C.getOptionalNoneDecl ();
468
+ auto *nilDRE = new (C) DeclRefExpr (
469
+ nilDecl, DeclNameLoc (), /* implicit*/ true );
470
+ auto *elseInitExpr = new (C) DotSyntaxCallExpr (
471
+ nilDRE, SourceLoc (), addressTE);
472
+ auto *elseInitCallExpr = CallExpr::createImplicit (
473
+ C, memberInitExpr, {elseInitExpr}, {C.getIdentifier (" _owning" )});
474
+
475
+ // Assign the current member to the result of the initializer call.
476
+ auto *memberDRE = new (C) MemberRefExpr (
477
+ selfDRE, SourceLoc (), member, DeclNameLoc (), /* Implicit*/ true );
478
+
479
+ auto *thenAssignMemberExpr = new (C) AssignExpr (
480
+ memberDRE, SourceLoc (), thenInitCallExpr, /* Implicit*/ true );
481
+ auto *elseAssignMemberExpr = new (C) AssignExpr (
482
+ memberDRE, SourceLoc (), elseInitCallExpr, /* Implicit*/ true );
483
+
484
+ thenMemberExprs.push_back (thenAssignMemberExpr);
485
+ elseMemberExprs.push_back (elseAssignMemberExpr);
486
+
487
+ // Advance the current address.
488
+ DeclName advancedName (C, C.getIdentifier (" advanced" ),
489
+ {C.getIdentifier (" by" )});
490
+ auto *advancedMethodExpr =
491
+ new (C) UnresolvedDotExpr (addressDRE, SourceLoc (),
492
+ advancedName, DeclNameLoc (),
493
+ /* Implicit*/ true );
494
+
495
+ // Obtain `MemberType._tensorHandleCount`.
496
+ auto *memberCountMRE = new (C) MemberRefExpr (
497
+ memberDRE, SourceLoc (), tensorHandleCountReq, DeclNameLoc (),
498
+ /* Implicit*/ true );
499
+
500
+ // Cast the tensor handle count to Int.
501
+ auto intInitName = DeclName (C, DeclBaseName::createConstructor (),
502
+ {Identifier ()});
503
+ auto *intInitExpr =
504
+ new (C) UnresolvedDotExpr (intTE, SourceLoc (), intInitName,
505
+ DeclNameLoc (), /* Implicit*/ true );
506
+ auto *intInitCallExpr = CallExpr::createImplicit (
507
+ C, intInitExpr, {memberCountMRE}, {Identifier ()});
508
+
509
+ // Assign the new address.
510
+ auto *assignAddrCallExpr = CallExpr::createImplicit (
511
+ C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier (" by" )});
512
+ auto *assignAddrExpr = new (C) AssignExpr (addressDRE, SourceLoc (),
513
+ assignAddrCallExpr,
514
+ /* Implicit*/ true );
515
+
516
+ thenMemberExprs.push_back (assignAddrExpr);
517
+ }
518
+
519
+ auto *thenBody = BraceStmt::create (
520
+ C, SourceLoc (), C.AllocateCopy (thenMemberExprs), SourceLoc (),
521
+ /* implicit*/ true );
522
+
523
+ auto *elseBody = BraceStmt::create (
524
+ C, SourceLoc (), C.AllocateCopy (elseMemberExprs), SourceLoc (),
525
+ /* implicit*/ true );
526
+
527
+ auto *ifStmt = new (C)
528
+ IfStmt (LabeledStmtInfo (), /* IfLoc*/ SourceLoc (),
529
+ /* Cond*/ C.AllocateCopy (cond), /* Then*/ thenBody,
530
+ /* ElseLoc*/ SourceLoc (), /* Else*/ elseBody, /* implicit*/ true );
531
+
532
+ funcDecl->setBody (BraceStmt::create (C, SourceLoc (), {ifStmt}, SourceLoc (),
533
+ /* implicit*/ true ));
534
+ }
535
+
536
+ // Synthesize a constructor declaration for a `TensorArrayProtocol`
537
+ // method requirement.
538
+ static ValueDecl *deriveTensorArrayProtocol_constructor (
539
+ DerivedConformance &derived, Identifier argument1Name,
540
+ Identifier parameter1Name, Type parameter1Type,
541
+ Identifier parameter2Name, Type parameter2Type, Type returnType,
542
+ AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
543
+ auto nominal = derived.Nominal ;
544
+ auto &C = derived.TC .Context ;
545
+ auto parentDC = derived.getConformanceContext ();
546
+
547
+ auto *param1 =
548
+ new (C) ParamDecl (VarDecl::Specifier::Default, SourceLoc (), SourceLoc (),
549
+ argument1Name, SourceLoc (), parameter1Name, parentDC);
550
+ param1->setInterfaceType (parameter1Type);
551
+ auto *param2 =
552
+ new (C) ParamDecl (VarDecl::Specifier::Default, SourceLoc (), SourceLoc (),
553
+ parameter2Name, SourceLoc (), parameter2Name, parentDC);
554
+ param2->setInterfaceType (parameter2Type);
555
+ ParameterList *params = ParameterList::create (C, {param1, param2});
556
+
557
+ DeclName name (C, DeclBaseName::createConstructor (), params);
558
+ auto *initDecl =
559
+ new (C) ConstructorDecl (name, SourceLoc (), OTK_None, SourceLoc (),
560
+ /* Throws*/ false , SourceLoc (), params,
561
+ /* GenericParams*/ nullptr , parentDC);
562
+ initDecl->setImplicit ();
563
+ initDecl->setSynthesized ();
564
+ initDecl->setBodySynthesizer (bodySynthesizer);
565
+
566
+ if (auto env = parentDC->getGenericEnvironmentOfContext ())
567
+ initDecl->setGenericEnvironment (env);
568
+ initDecl->computeType (AnyFunctionType::ExtInfo ().withThrows (false ));
569
+ initDecl->copyFormalAccessFrom (nominal, /* sourceIsParentContext*/ true );
570
+ initDecl->setValidationToChecked ();
571
+
572
+ derived.addMembersToConformanceContext ({initDecl});
573
+ C.addSynthesizedDecl (initDecl);
574
+
575
+ return initDecl;
576
+ }
577
+
578
+ // Synthesize the `init(_owning:count:)` function declaration.
579
+ static ValueDecl
580
+ *deriveTensorArrayProtocol_init (DerivedConformance &derived) {
581
+ auto &C = derived.TC .Context ;
582
+
583
+ // Obtain the address type.
584
+ auto cTensorHandleType = C.getOpaquePointerDecl ()->getDeclaredType ();
585
+ Type baseAddressType = BoundGenericType::get (
586
+ C.getUnsafePointerDecl (), Type (), {cTensorHandleType});
587
+ Type addressType = BoundGenericType::get (
588
+ C.getOptionalDecl (), Type (), {baseAddressType});
589
+ Type intType = C.getIntDecl ()->getDeclaredType ();
590
+ Type voidType = C.getVoidDecl ()->getDeclaredInterfaceType ();
591
+
592
+ return deriveTensorArrayProtocol_constructor (
593
+ derived, C.getIdentifier (" _owning" ), C.getIdentifier (" tensorHandles" ),
594
+ addressType, C.getIdentifier (" count" ), intType, voidType,
595
+ deriveBodyTensorArrayProtocol_init);
596
+ }
597
+
352
598
ValueDecl *DerivedConformance::deriveTensorArrayProtocol (
353
599
ValueDecl *requirement) {
354
600
if (requirement->getBaseName () == TC.Context .Id_unpackTensorHandles )
355
601
return deriveTensorArrayProtocol_unpackTensorHandles (*this );
356
602
if (requirement->getBaseName () == TC.Context .Id_tensorHandleCount )
357
603
return deriveTensorArrayProtocol_tensorHandleCount (*this );
604
+ if (requirement->getBaseName () == DeclBaseName::createConstructor ())
605
+ return deriveTensorArrayProtocol_init (*this );
358
606
TC.diagnose (requirement->getLoc (),
359
607
diag::broken_tensor_array_protocol_requirement);
360
608
return nullptr ;
0 commit comments