Skip to content

Commit 701e31d

Browse files
committed
Updated the 'TensorArrayProtocol' and its automatic derivation implementation.
1 parent ce4dfd3 commit 701e31d

File tree

4 files changed

+269
-7
lines changed

4 files changed

+269
-7
lines changed

lib/Sema/DerivedConformanceTensorArrayProtocol.cpp

Lines changed: 252 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal,
3737
auto *structDecl = dyn_cast<StructDecl>(nominal);
3838
if (!structDecl)
3939
return false;
40-
// All stored properties must conform to `TensorArrayProtocol`.
40+
// All stored properties must conform to `TensorGroup`.
4141
auto &C = nominal->getASTContext();
42-
auto *tensorArrayProto =
43-
C.getProtocol(KnownProtocolKind::TensorArrayProtocol);
42+
auto *tensorGroupProto =
43+
C.getProtocol(KnownProtocolKind::TensorGroup);
4444
return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
4545
if (!v->hasInterfaceType())
4646
C.getLazyResolver()->resolveDeclSignature(v);
4747
if (!v->hasInterfaceType())
4848
return false;
4949
auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
50-
return (bool)TypeChecker::conformsToProtocol(varType, tensorArrayProto, DC,
50+
return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC,
5151
ConformanceCheckFlags::Used);
5252
});
5353
}
@@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
6666
return lookup.front();
6767
}
6868

69+
// Return the protocol requirement with the specified name.
70+
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) {
71+
auto lookup = proto->lookupDirect(name);
72+
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
73+
[](ValueDecl *v) {
74+
return !isa<ProtocolDecl>(
75+
v->getDeclContext()) ||
76+
!v->isProtocolRequirement();
77+
}),
78+
lookup.end());
79+
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
80+
return lookup.front();
81+
}
82+
6983
// Synthesize body for `_unpackTensorHandles(into:)`.
7084
static void
7185
deriveBodyTensorArrayProtocol_unpackTensorHandles(
@@ -349,12 +363,246 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
349363
return tensorHandleCountDecl;
350364
}
351365

366+
// Synthesize body for `init(_owning:count:)`.
367+
static void
368+
deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
369+
auto *parentDC = funcDecl->getParent();
370+
auto *nominal = parentDC->getSelfNominalTypeDecl();
371+
auto &C = nominal->getASTContext();
372+
373+
// Obtain the address type.
374+
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
375+
auto baseAddressType = BoundGenericType::get(
376+
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
377+
auto addressType = BoundGenericType::get(
378+
C.getOptionalDecl(), Type(), {baseAddressType});
379+
auto *addressTE = TypeExpr::createImplicit(addressType, C);
380+
381+
// Get references to `self` and parameter declarations.
382+
auto *selfDecl = funcDecl->getImplicitSelfDecl();
383+
auto *selfDRE = new (C)
384+
DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
385+
auto *paramDecl = funcDecl->getParameters()->get(0);
386+
auto *paramDRE = new (C)
387+
DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
388+
389+
// Create an `if var` statement for the current address.
390+
VarDecl *currAddressDecl = new (C) VarDecl(
391+
/*IsStatic*/ false, VarDecl::Specifier::Var, /*IsCaptureList*/ false,
392+
SourceLoc(), C.getIdentifier("currentAddress"), funcDecl);
393+
currAddressDecl->setImplicit();
394+
currAddressDecl->setHasNonPatternBindingInit(true);
395+
currAddressDecl->setInterfaceType(baseAddressType);
396+
currAddressDecl->setValidationToChecked();
397+
398+
Pattern *currAddressPat = new (C)
399+
NamedPattern(currAddressDecl, /*implicit*/ true);
400+
currAddressPat = new (C)
401+
VarPattern(SourceLoc(), /*isLet*/ false, currAddressPat,
402+
/*implicit*/ true);
403+
currAddressPat = new (C)
404+
OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc(),
405+
/*implicit*/ true);
406+
StmtConditionElement cond[] = {
407+
StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)};
408+
409+
// Get the necessary protocol requirements.
410+
auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
411+
auto *tensorArrayProto = C.getProtocol(
412+
KnownProtocolKind::TensorArrayProtocol);
413+
auto initName = DeclName(
414+
C, DeclBaseName::createConstructor(),
415+
{C.getIdentifier("_owning"), C.getIdentifier("count")});
416+
auto *initReq = getProtocolRequirement(tensorArrayProto, initName);
417+
auto *tensorHandleCountReq = getProtocolRequirement(
418+
tensorArrayProto, C.Id_tensorHandleCount);
419+
420+
Type intType = C.getIntDecl()->getDeclaredType();
421+
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);
422+
423+
// Goes through the member TensorGroups and call
424+
// `self.t = T(_owning:count:)`.
425+
llvm::SmallVector<ASTNode, 2> thenMemberExprs;
426+
llvm::SmallVector<ASTNode, 2> elseMemberExprs;
427+
for (auto member : nominal->getStoredProperties()) {
428+
auto memberType = parentDC->mapTypeIntoContext(
429+
member->getValueInterfaceType());
430+
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
431+
auto module = nominal->getModuleContext();
432+
auto confRef = module->lookupConformance(
433+
memberType, tensorGroupProto);
434+
assert(confRef && "Member does not conform to `TensorGroup`");
435+
436+
// Get member type's constructor, e.g. `MemberType.init(_owning:)`.
437+
// Use protocol requirement declaration for the method by default: this
438+
// will be dynamically dispatched.
439+
ValueDecl *memberInitDecl = initReq;
440+
// If conformance reference is concrete, then use concrete witness
441+
// declaration for the constructor.
442+
if (confRef->isConcrete())
443+
memberInitDecl = confRef->getConcrete()->getWitnessDecl(
444+
initReq, C.getLazyResolver());
445+
assert(memberInitDecl && "Member constructor declaration must exist");
446+
auto memberInitDRE = new (C) DeclRefExpr(
447+
memberInitDecl, DeclNameLoc(), /*implicit*/ true);
448+
memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
449+
450+
// Create reference to member constructor: `MemberType.init(_owning:)`.
451+
auto *memberInitExpr = new (C) ConstructorRefCallExpr(
452+
memberInitDRE, memberTypeExpr);
453+
454+
auto *addressDRE = new (C) DeclRefExpr(
455+
currAddressDecl, DeclNameLoc(), /*implicit*/ true);
456+
auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType);
457+
458+
// Initialize the member using its TensorGroup constructor.
459+
// Note that, initialization is dependent on the branch of the
460+
// if-statement taken.
461+
auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType);
462+
auto *thenInitCallExpr = CallExpr::createImplicit(
463+
C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")});
464+
465+
// Create a nil expression with type UnsafePointer<CTensorHandle>? for the
466+
// `else` branch.
467+
auto *nilDecl = C.getOptionalNoneDecl();
468+
auto *nilDRE = new (C) DeclRefExpr(
469+
nilDecl, DeclNameLoc(), /*implicit*/ true);
470+
auto *elseInitExpr = new (C) DotSyntaxCallExpr(
471+
nilDRE, SourceLoc(), addressTE);
472+
auto *elseInitCallExpr = CallExpr::createImplicit(
473+
C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});
474+
475+
// Assign the current member to the result of the initializer call.
476+
auto *memberDRE = new (C) MemberRefExpr(
477+
selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
478+
479+
auto *thenAssignMemberExpr = new (C) AssignExpr(
480+
memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true);
481+
auto *elseAssignMemberExpr = new (C) AssignExpr(
482+
memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true);
483+
484+
thenMemberExprs.push_back(thenAssignMemberExpr);
485+
elseMemberExprs.push_back(elseAssignMemberExpr);
486+
487+
// Advance the current address.
488+
DeclName advancedName(C, C.getIdentifier("advanced"),
489+
{C.getIdentifier("by")});
490+
auto *advancedMethodExpr =
491+
new (C) UnresolvedDotExpr(addressDRE, SourceLoc(),
492+
advancedName, DeclNameLoc(),
493+
/*Implicit*/ true);
494+
495+
// Obtain `MemberType._tensorHandleCount`.
496+
auto *memberCountMRE = new (C) MemberRefExpr(
497+
memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(),
498+
/*Implicit*/ true);
499+
500+
// Cast the tensor handle count to Int.
501+
auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
502+
{Identifier()});
503+
auto *intInitExpr =
504+
new (C) UnresolvedDotExpr(intTE, SourceLoc(), intInitName,
505+
DeclNameLoc(), /*Implicit*/ true);
506+
auto *intInitCallExpr = CallExpr::createImplicit(
507+
C, intInitExpr, {memberCountMRE}, {Identifier()});
508+
509+
// Assign the new address.
510+
auto *assignAddrCallExpr = CallExpr::createImplicit(
511+
C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")});
512+
auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(),
513+
assignAddrCallExpr,
514+
/*Implicit*/ true);
515+
516+
thenMemberExprs.push_back(assignAddrExpr);
517+
}
518+
519+
auto *thenBody = BraceStmt::create(
520+
C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(),
521+
/*implicit*/ true);
522+
523+
auto *elseBody = BraceStmt::create(
524+
C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(),
525+
/*implicit*/ true);
526+
527+
auto *ifStmt = new (C)
528+
IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(),
529+
/*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody,
530+
/*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true);
531+
532+
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(),
533+
/*implicit*/ true));
534+
}
535+
536+
// Synthesize a constructor declaration for a `TensorArrayProtocol`
537+
// method requirement.
538+
static ValueDecl *deriveTensorArrayProtocol_constructor(
539+
DerivedConformance &derived, Identifier argument1Name,
540+
Identifier parameter1Name, Type parameter1Type,
541+
Identifier parameter2Name, Type parameter2Type, Type returnType,
542+
AbstractFunctionDecl::BodySynthesizer bodySynthesizer) {
543+
auto nominal = derived.Nominal;
544+
auto &C = derived.TC.Context;
545+
auto parentDC = derived.getConformanceContext();
546+
547+
auto *param1 =
548+
new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
549+
argument1Name, SourceLoc(), parameter1Name, parentDC);
550+
param1->setInterfaceType(parameter1Type);
551+
auto *param2 =
552+
new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
553+
parameter2Name, SourceLoc(), parameter2Name, parentDC);
554+
param2->setInterfaceType(parameter2Type);
555+
ParameterList *params = ParameterList::create(C, {param1, param2});
556+
557+
DeclName name(C, DeclBaseName::createConstructor(), params);
558+
auto *initDecl =
559+
new (C) ConstructorDecl(name, SourceLoc(), OTK_None, SourceLoc(),
560+
/*Throws*/ false, SourceLoc(), params,
561+
/*GenericParams*/ nullptr, parentDC);
562+
initDecl->setImplicit();
563+
initDecl->setSynthesized();
564+
initDecl->setBodySynthesizer(bodySynthesizer);
565+
566+
if (auto env = parentDC->getGenericEnvironmentOfContext())
567+
initDecl->setGenericEnvironment(env);
568+
initDecl->computeType(AnyFunctionType::ExtInfo().withThrows(false));
569+
initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
570+
initDecl->setValidationToChecked();
571+
572+
derived.addMembersToConformanceContext({initDecl});
573+
C.addSynthesizedDecl(initDecl);
574+
575+
return initDecl;
576+
}
577+
578+
// Synthesize the `init(_owning:count:)` function declaration.
579+
static ValueDecl
580+
*deriveTensorArrayProtocol_init(DerivedConformance &derived) {
581+
auto &C = derived.TC.Context;
582+
583+
// Obtain the address type.
584+
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
585+
Type baseAddressType = BoundGenericType::get(
586+
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
587+
Type addressType = BoundGenericType::get(
588+
C.getOptionalDecl(), Type(), {baseAddressType});
589+
Type intType = C.getIntDecl()->getDeclaredType();
590+
Type voidType = C.getVoidDecl()->getDeclaredInterfaceType();
591+
592+
return deriveTensorArrayProtocol_constructor(
593+
derived, C.getIdentifier("_owning"), C.getIdentifier("tensorHandles"),
594+
addressType, C.getIdentifier("count"), intType, voidType,
595+
deriveBodyTensorArrayProtocol_init);
596+
}
597+
352598
ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
353599
ValueDecl *requirement) {
354600
if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles)
355601
return deriveTensorArrayProtocol_unpackTensorHandles(*this);
356602
if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount)
357603
return deriveTensorArrayProtocol_tensorHandleCount(*this);
604+
if (requirement->getBaseName() == DeclBaseName::createConstructor())
605+
return deriveTensorArrayProtocol_init(*this);
358606
TC.diagnose(requirement->getLoc(),
359607
diag::broken_tensor_array_protocol_requirement);
360608
return nullptr;

lib/Sema/DerivedConformances.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
340340
if (argumentNames[0] == ctx.getIdentifier("_owning")) {
341341
return getRequirement(KnownProtocolKind::TensorGroup);
342342
}
343+
} else if (argumentNames.size() == 2) {
344+
// SWIFT_ENABLE_TENSORFLOW
345+
// TensorArrayProtocol.init(_owning:count)
346+
if (argumentNames[0] == ctx.getIdentifier("_owning") &&
347+
argumentNames[0] == ctx.getIdentifier("count")) {
348+
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
349+
}
343350
}
344351

345352
return nullptr;

stdlib/public/TensorFlow/TensorGroup.swift

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,13 @@ import CTensorFlow
2121
/// This protocol is defined separately from `TensorGroup` in order for the
2222
/// number of tensors to be determined at runtime. For example,
2323
/// `[Tensor<Float>]` may have an unknown number of elements at compile time.
24+
///
25+
/// This protocol can be derived automatically for structs whose stored
26+
/// properties all conform to the `TensorGroup` protocol. It cannot be derived
27+
/// automatically for structs whose properties all conform to
28+
/// `TensorArrayProtocol` due to the constructor requirement (i.e., in such
29+
/// cases it would be impossible to know how to break down `count` among the
30+
/// stored properties).
2431
public protocol TensorArrayProtocol {
2532
/// Writes the tensor handles to `address`, which must be allocated
2633
/// with enough capacity to hold `_tensorHandleCount` handles. The tensor

test/TensorFlowRuntime/tensor_array_protocol.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import StdlibUnittest
1111

1212
var TensorArrayProtocolTests = TestSuite("TensorArrayProtocol")
1313

14-
struct Empty : TensorArrayProtocol {}
14+
struct Empty : TensorGroup {}
1515

16-
struct Simple : TensorArrayProtocol {
16+
struct Simple : TensorGroup {
1717
var w, b: Tensor<Float>
1818
}
1919

@@ -32,7 +32,7 @@ struct Nested : TensorArrayProtocol {
3232
var mixed: Mixed
3333
}
3434

35-
struct Generic<T: TensorArrayProtocol, U: TensorArrayProtocol> : TensorArrayProtocol {
35+
struct Generic<T: TensorGroup, U: TensorGroup> : TensorArrayProtocol {
3636
var t: T
3737
var u: U
3838
}

0 commit comments

Comments
 (0)