Skip to content

Commit 82431d4

Browse files
committed
Addressed Dan's comments.
1 parent 6765b94 commit 82431d4

File tree

1 file changed

+24
-38
lines changed

1 file changed

+24
-38
lines changed

lib/Sema/DerivedConformanceTensorArrayProtocol.cpp

Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -419,8 +419,7 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
419419
Type intType = C.getIntDecl()->getDeclaredType();
420420
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);
421421

422-
// Goes through the member TensorGroups and call
423-
// `self.t = T(_owning:count:)`.
422+
// Iterate over members and call `self.t = T(_owning:)`.
424423
llvm::SmallVector<ASTNode, 2> thenMemberExprs;
425424
llvm::SmallVector<ASTNode, 2> elseMemberExprs;
426425
for (auto member : nominal->getStoredProperties()) {
@@ -532,25 +531,32 @@ deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
532531
/*implicit*/ true));
533532
}
534533

535-
// Synthesize a constructor declaration for a `TensorArrayProtocol`
536-
// method requirement.
537-
static ValueDecl *deriveTensorArrayProtocol_constructor(
538-
DerivedConformance &derived, Identifier argument1Name,
539-
Identifier parameter1Name, Type parameter1Type,
540-
Identifier parameter2Name, Type parameter2Type, Type returnType,
541-
AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
534+
// Synthesize the `init(_owning:count:)` function declaration.
535+
static ValueDecl
536+
*deriveTensorArrayProtocol_init(DerivedConformance &derived) {
537+
auto &C = derived.TC.Context;
538+
539+
// Obtain the address type.
540+
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
541+
Type baseAddressType = BoundGenericType::get(
542+
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
543+
Type addressType = BoundGenericType::get(
544+
C.getOptionalDecl(), Type(), {baseAddressType});
545+
Type intType = C.getIntDecl()->getDeclaredType();
546+
542547
auto nominal = derived.Nominal;
543548
auto &C = derived.TC.Context;
544549
auto parentDC = derived.getConformanceContext();
545550

546-
auto *param1 =
547-
new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
548-
argument1Name, SourceLoc(), parameter1Name, parentDC);
549-
param1->setInterfaceType(parameter1Type);
550-
auto *param2 =
551-
new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
552-
parameter2Name, SourceLoc(), parameter2Name, parentDC);
553-
param2->setInterfaceType(parameter2Type);
551+
auto *param1 = new (C) ParamDecl(
552+
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
553+
C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"),
554+
parentDC);
555+
param1->setInterfaceType(addressType);
556+
auto *param2 = new (C) ParamDecl(
557+
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
558+
C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC);
559+
param2->setInterfaceType(intType);
554560
ParameterList *params = ParameterList::create(C, {param1, param2});
555561

556562
DeclName name(C, DeclBaseName::createConstructor(), params);
@@ -560,7 +566,7 @@ static ValueDecl *deriveTensorArrayProtocol_constructor(
560566
/*GenericParams*/ nullptr, parentDC);
561567
initDecl->setImplicit();
562568
initDecl->setSynthesized();
563-
initDecl->setBodySynthesizer(bodySynthesizer);
569+
initDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_init);
564570

565571
if (auto env = parentDC->getGenericEnvironmentOfContext())
566572
initDecl->setGenericEnvironment(env);
@@ -574,26 +580,6 @@ static ValueDecl *deriveTensorArrayProtocol_constructor(
574580
return initDecl;
575581
}
576582

577-
// Synthesize the `init(_owning:count:)` function declaration.
578-
static ValueDecl
579-
*deriveTensorArrayProtocol_init(DerivedConformance &derived) {
580-
auto &C = derived.TC.Context;
581-
582-
// Obtain the address type.
583-
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
584-
Type baseAddressType = BoundGenericType::get(
585-
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
586-
Type addressType = BoundGenericType::get(
587-
C.getOptionalDecl(), Type(), {baseAddressType});
588-
Type intType = C.getIntDecl()->getDeclaredType();
589-
Type voidType = C.getVoidDecl()->getDeclaredInterfaceType();
590-
591-
return deriveTensorArrayProtocol_constructor(
592-
derived, C.getIdentifier("_owning"), C.getIdentifier("tensorHandles"),
593-
addressType, C.getIdentifier("count"), intType, voidType,
594-
deriveBodyTensorArrayProtocol_init);
595-
}
596-
597583
ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
598584
ValueDecl *requirement) {
599585
if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles)

0 commit comments

Comments
 (0)