Skip to content

Commit baee206

Browse files
eaplataniosrxwei
authored andcommitted
[TF] Updates to TensorArrayProtocol (#24229)
* Changed 'TensorArrayProtocol' such that it can be used to support output tensor arrays in raw ops. * Added a '_typeList' property to 'TensorArrayProtocol'. Friend PR: tensorflow/swift-bindings#26 .
1 parent b9e9051 commit baee206

File tree

6 files changed

+390
-17
lines changed

6 files changed

+390
-17
lines changed

lib/Sema/DerivedConformanceTensorArrayProtocol.cpp

Lines changed: 320 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal,
3737
auto *structDecl = dyn_cast<StructDecl>(nominal);
3838
if (!structDecl)
3939
return false;
40-
// All stored properties must conform to `TensorArrayProtocol`.
40+
// All stored properties must conform to `TensorGroup`.
4141
auto &C = nominal->getASTContext();
42-
auto *tensorArrayProto =
43-
C.getProtocol(KnownProtocolKind::TensorArrayProtocol);
42+
auto *tensorGroupProto =
43+
C.getProtocol(KnownProtocolKind::TensorGroup);
4444
return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
4545
if (!v->hasInterfaceType())
4646
C.getLazyResolver()->resolveDeclSignature(v);
4747
if (!v->hasInterfaceType())
4848
return false;
4949
auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
50-
return (bool)TypeChecker::conformsToProtocol(varType, tensorArrayProto, DC,
50+
return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC,
5151
ConformanceCheckFlags::Used);
5252
});
5353
}
@@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
6666
return lookup.front();
6767
}
6868

69+
// Return the protocol requirement with the specified name.
70+
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) {
71+
auto lookup = proto->lookupDirect(name);
72+
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
73+
[](ValueDecl *v) {
74+
return !isa<ProtocolDecl>(
75+
v->getDeclContext()) ||
76+
!v->isProtocolRequirement();
77+
}),
78+
lookup.end());
79+
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
80+
return lookup.front();
81+
}
82+
6983
// Synthesize body for `_unpackTensorHandles(into:)`.
7084
static void
7185
deriveBodyTensorArrayProtocol_unpackTensorHandles(
@@ -349,12 +363,314 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
349363
return tensorHandleCountDecl;
350364
}
351365

366+
367+
/// Derive the body for the '_typeList' getter.
368+
static void
369+
deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl) {
370+
auto *parentDC = funcDecl->getParent();
371+
auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
372+
auto &C = nominal->getASTContext();
373+
374+
auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
375+
auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList);
376+
377+
// Concatenate all member `_typeList` arrays.
378+
Type arrayType = BoundGenericType::get(
379+
C.getArrayDecl(), Type(),
380+
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
381+
auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C);
382+
auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
383+
assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
384+
ValueDecl *plusOpDecl = plusOpLookup.front();
385+
auto plusOpDRE = new (C)
386+
DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
387+
auto plusOpExpr = new (C)
388+
DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr);
389+
Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
390+
for (auto member : nominal->getStoredProperties()) {
391+
auto memberType =
392+
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
393+
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
394+
auto *memberTypeListExpr = new (C)
395+
MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq,
396+
DeclNameLoc(), /*Implicit*/ true);
397+
// Create expression `lhsArg + rhsArg`.
398+
auto *plusOpArgs =
399+
TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr},
400+
{}, {}, SourceLoc(), /*HasTrailingClosure*/ false,
401+
/*Implicit*/ true);
402+
typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs,
403+
/*Implicit*/ true);
404+
}
405+
406+
// Return the resulting data types array.
407+
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr);
408+
auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
409+
/*Implicit*/ true);
410+
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
411+
/*Implicit*/ true));
412+
}
413+
414+
/// Derive a '_typeList' implementation.
415+
static ValueDecl *deriveTensorArrayProtocol_typeList(
416+
DerivedConformance &derived) {
417+
auto nominal = derived.Nominal;
418+
auto &TC = derived.TC;
419+
ASTContext &C = TC.Context;
420+
421+
auto parentDC = derived.getConformanceContext();
422+
Type dataTypeArrayType = BoundGenericType::get(
423+
C.getArrayDecl(), Type(),
424+
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
425+
auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType);
426+
427+
// Create `_typeList` property declaration.
428+
VarDecl *typeListDecl;
429+
PatternBindingDecl *patDecl;
430+
std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty(
431+
C.Id_typeList, returnType, returnType, /*isStatic*/ false,
432+
/*isFinal*/ false);
433+
434+
// Add `@inlinable` to the `_typeList` declaration.
435+
if (nominal->getEffectiveAccess() > AccessLevel::Internal)
436+
typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));
437+
438+
// Create `_typeList` getter.
439+
auto *getterDecl = derived.declareDerivedPropertyGetter(
440+
TC, typeListDecl, returnType);
441+
getterDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_typeList);
442+
typeListDecl->setAccessors(StorageImplInfo::getImmutableComputed(),
443+
SourceLoc(), {getterDecl}, SourceLoc());
444+
derived.addMembersToConformanceContext({getterDecl, typeListDecl, patDecl});
445+
446+
return typeListDecl;
447+
}
448+
449+
// Synthesize body for `init(_owning:count:)`.
450+
static void
451+
deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
452+
auto *parentDC = funcDecl->getParent();
453+
auto *nominal = parentDC->getSelfNominalTypeDecl();
454+
auto &C = nominal->getASTContext();
455+
456+
// Obtain the address type.
457+
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
458+
auto baseAddressType = BoundGenericType::get(
459+
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
460+
auto addressType = BoundGenericType::get(
461+
C.getOptionalDecl(), Type(), {baseAddressType});
462+
auto *addressTE = TypeExpr::createImplicit(addressType, C);
463+
464+
// Get references to `self` and parameter declarations.
465+
auto *selfDecl = funcDecl->getImplicitSelfDecl();
466+
auto *selfDRE = new (C)
467+
DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
468+
auto *paramDecl = funcDecl->getParameters()->get(0);
469+
auto *paramDRE = new (C)
470+
DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);
471+
472+
// Create an `if var` statement for the current address.
473+
VarDecl *currAddressDecl = new (C) VarDecl(
474+
/*IsStatic*/ false, VarDecl::Specifier::Var, /*IsCaptureList*/ false,
475+
SourceLoc(), C.getIdentifier("currentAddress"), funcDecl);
476+
currAddressDecl->setImplicit();
477+
currAddressDecl->setHasNonPatternBindingInit(true);
478+
currAddressDecl->setInterfaceType(baseAddressType);
479+
currAddressDecl->setValidationToChecked();
480+
481+
Pattern *currAddressPat = new (C)
482+
NamedPattern(currAddressDecl, /*implicit*/ true);
483+
currAddressPat = new (C)
484+
VarPattern(SourceLoc(), /*isLet*/ false, currAddressPat,
485+
/*implicit*/ true);
486+
currAddressPat = new (C)
487+
OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc(),
488+
/*implicit*/ true);
489+
StmtConditionElement cond[] = {
490+
StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)};
491+
492+
// Get the necessary protocol requirements.
493+
auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
494+
auto *tensorArrayProto = C.getProtocol(
495+
KnownProtocolKind::TensorArrayProtocol);
496+
auto initName = DeclName(
497+
C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")});
498+
auto *initReq = getProtocolRequirement(tensorGroupProto, initName);
499+
auto *tensorHandleCountReq = getProtocolRequirement(
500+
tensorArrayProto, C.Id_tensorHandleCount);
501+
502+
Type intType = C.getIntDecl()->getDeclaredType();
503+
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);
504+
505+
// Iterate over members and call `self.t = T(_owning:)`.
506+
llvm::SmallVector<ASTNode, 2> thenMemberExprs;
507+
llvm::SmallVector<ASTNode, 2> elseMemberExprs;
508+
for (auto member : nominal->getStoredProperties()) {
509+
auto memberType = parentDC->mapTypeIntoContext(
510+
member->getValueInterfaceType());
511+
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
512+
auto module = nominal->getModuleContext();
513+
auto confRef = module->lookupConformance(
514+
memberType, tensorGroupProto);
515+
assert(confRef && "Member does not conform to `TensorGroup`");
516+
517+
// Get member type's constructor, e.g. `MemberType.init(_owning:)`.
518+
// Use protocol requirement declaration for the method by default: this
519+
// will be dynamically dispatched.
520+
ValueDecl *memberInitDecl = initReq;
521+
// If conformance reference is concrete, then use concrete witness
522+
// declaration for the constructor.
523+
if (confRef->isConcrete())
524+
memberInitDecl = confRef->getConcrete()->getWitnessDecl(
525+
initReq, C.getLazyResolver());
526+
assert(memberInitDecl && "Member constructor declaration must exist");
527+
auto memberInitDRE = new (C) DeclRefExpr(
528+
memberInitDecl, DeclNameLoc(), /*implicit*/ true);
529+
memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
530+
531+
// Create reference to member constructor: `MemberType.init(_owning:)`.
532+
auto *memberInitExpr = new (C) ConstructorRefCallExpr(
533+
memberInitDRE, memberTypeExpr);
534+
535+
auto *addressDRE = new (C) DeclRefExpr(
536+
currAddressDecl, DeclNameLoc(), /*implicit*/ true);
537+
auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType);
538+
539+
// Initialize the member using its TensorGroup constructor.
540+
// Note that, initialization is dependent on the branch of the
541+
// if-statement taken.
542+
auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType);
543+
auto *thenInitCallExpr = CallExpr::createImplicit(
544+
C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")});
545+
546+
// Create a nil expression with type UnsafePointer<CTensorHandle>? for the
547+
// `else` branch.
548+
auto *nilDecl = C.getOptionalNoneDecl();
549+
auto *nilDRE = new (C) DeclRefExpr(
550+
nilDecl, DeclNameLoc(), /*implicit*/ true);
551+
auto *elseInitExpr = new (C) DotSyntaxCallExpr(
552+
nilDRE, SourceLoc(), addressTE);
553+
auto *elseInitCallExpr = CallExpr::createImplicit(
554+
C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});
555+
556+
// Assign the current member to the result of the initializer call.
557+
auto *memberDRE = new (C) MemberRefExpr(
558+
selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);
559+
560+
auto *thenAssignMemberExpr = new (C) AssignExpr(
561+
memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true);
562+
auto *elseAssignMemberExpr = new (C) AssignExpr(
563+
memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true);
564+
565+
thenMemberExprs.push_back(thenAssignMemberExpr);
566+
elseMemberExprs.push_back(elseAssignMemberExpr);
567+
568+
// Advance the current address.
569+
DeclName advancedName(C, C.getIdentifier("advanced"),
570+
{C.getIdentifier("by")});
571+
auto *advancedMethodExpr =
572+
new (C) UnresolvedDotExpr(addressDRE, SourceLoc(),
573+
advancedName, DeclNameLoc(),
574+
/*Implicit*/ true);
575+
576+
// Obtain `MemberType._tensorHandleCount`.
577+
auto *memberCountMRE = new (C) MemberRefExpr(
578+
memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(),
579+
/*Implicit*/ true);
580+
581+
// Cast the tensor handle count to Int.
582+
auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
583+
{Identifier()});
584+
auto *intInitExpr =
585+
new (C) UnresolvedDotExpr(intTE, SourceLoc(), intInitName,
586+
DeclNameLoc(), /*Implicit*/ true);
587+
auto *intInitCallExpr = CallExpr::createImplicit(
588+
C, intInitExpr, {memberCountMRE}, {Identifier()});
589+
590+
// Assign the new address.
591+
auto *assignAddrCallExpr = CallExpr::createImplicit(
592+
C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")});
593+
auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(),
594+
assignAddrCallExpr,
595+
/*Implicit*/ true);
596+
597+
thenMemberExprs.push_back(assignAddrExpr);
598+
}
599+
600+
auto *thenBody = BraceStmt::create(
601+
C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(),
602+
/*implicit*/ true);
603+
604+
auto *elseBody = BraceStmt::create(
605+
C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(),
606+
/*implicit*/ true);
607+
608+
auto *ifStmt = new (C)
609+
IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(),
610+
/*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody,
611+
/*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true);
612+
613+
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(),
614+
/*implicit*/ true));
615+
}
616+
617+
// Synthesize the `init(_owning:count:)` function declaration.
618+
static ValueDecl
619+
*deriveTensorArrayProtocol_init(DerivedConformance &derived) {
620+
auto &C = derived.TC.Context;
621+
auto nominal = derived.Nominal;
622+
auto parentDC = derived.getConformanceContext();
623+
624+
// Obtain the address type.
625+
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
626+
Type baseAddressType = BoundGenericType::get(
627+
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
628+
Type addressType = BoundGenericType::get(
629+
C.getOptionalDecl(), Type(), {baseAddressType});
630+
Type intType = C.getIntDecl()->getDeclaredType();
631+
632+
auto *param1 = new (C) ParamDecl(
633+
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
634+
C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"),
635+
parentDC);
636+
param1->setInterfaceType(addressType);
637+
auto *param2 = new (C) ParamDecl(
638+
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
639+
C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC);
640+
param2->setInterfaceType(intType);
641+
ParameterList *params = ParameterList::create(C, {param1, param2});
642+
643+
DeclName name(C, DeclBaseName::createConstructor(), params);
644+
auto *initDecl =
645+
new (C) ConstructorDecl(name, SourceLoc(), OTK_None, SourceLoc(),
646+
/*Throws*/ false, SourceLoc(), params,
647+
/*GenericParams*/ nullptr, parentDC);
648+
initDecl->setImplicit();
649+
initDecl->setSynthesized();
650+
initDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_init);
651+
652+
if (auto env = parentDC->getGenericEnvironmentOfContext())
653+
initDecl->setGenericEnvironment(env);
654+
initDecl->computeType(AnyFunctionType::ExtInfo().withThrows(false));
655+
initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
656+
initDecl->setValidationToChecked();
657+
658+
derived.addMembersToConformanceContext({initDecl});
659+
C.addSynthesizedDecl(initDecl);
660+
661+
return initDecl;
662+
}
663+
352664
ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
353665
ValueDecl *requirement) {
354666
if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles)
355667
return deriveTensorArrayProtocol_unpackTensorHandles(*this);
356668
if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount)
357669
return deriveTensorArrayProtocol_tensorHandleCount(*this);
670+
if (requirement->getBaseName() == TC.Context.Id_typeList)
671+
return deriveTensorArrayProtocol_typeList(*this);
672+
if (requirement->getBaseName() == DeclBaseName::createConstructor())
673+
return deriveTensorArrayProtocol_init(*this);
358674
TC.diagnose(requirement->getLoc(),
359675
diag::broken_tensor_array_protocol_requirement);
360676
return nullptr;

lib/Sema/DerivedConformances.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
230230
// TensorArrayProtocol._tensorHandleCount
231231
if (name.isSimpleName(ctx.Id_tensorHandleCount))
232232
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
233+
234+
// SWIFT_ENABLE_TENSORFLOW
235+
// TensorArrayProtocol._typeList
236+
if (name.isSimpleName(ctx.Id_typeList) && !requirement->isStatic())
237+
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
233238

234239
// SWIFT_ENABLE_TENSORFLOW
235240
// TensorGroup._typeList
@@ -340,6 +345,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
340345
if (argumentNames[0] == ctx.getIdentifier("_owning")) {
341346
return getRequirement(KnownProtocolKind::TensorGroup);
342347
}
348+
} else if (argumentNames.size() == 2) {
349+
// SWIFT_ENABLE_TENSORFLOW
350+
// TensorArrayProtocol.init(_owning:count)
351+
if (argumentNames[0] == ctx.getIdentifier("_owning") &&
352+
argumentNames[1] == ctx.getIdentifier("count")) {
353+
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
354+
}
343355
}
344356

345357
return nullptr;

0 commit comments

Comments
 (0)