Skip to content

[TF] Updates to TensorArrayProtocol derived conformances #24229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 27, 2019
Merged
324 changes: 320 additions & 4 deletions lib/Sema/DerivedConformanceTensorArrayProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ bool DerivedConformance::canDeriveTensorArrayProtocol(NominalTypeDecl *nominal,
auto *structDecl = dyn_cast<StructDecl>(nominal);
if (!structDecl)
return false;
// All stored properties must conform to `TensorArrayProtocol`.
// All stored properties must conform to `TensorGroup`.
auto &C = nominal->getASTContext();
auto *tensorArrayProto =
C.getProtocol(KnownProtocolKind::TensorArrayProtocol);
auto *tensorGroupProto =
C.getProtocol(KnownProtocolKind::TensorGroup);
return llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
if (!v->hasInterfaceType())
C.getLazyResolver()->resolveDeclSignature(v);
if (!v->hasInterfaceType())
return false;
auto varType = DC->mapTypeIntoContext(v->getValueInterfaceType());
return (bool)TypeChecker::conformsToProtocol(varType, tensorArrayProto, DC,
return (bool)TypeChecker::conformsToProtocol(varType, tensorGroupProto, DC,
ConformanceCheckFlags::Used);
});
}
Expand All @@ -66,6 +66,20 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
return lookup.front();
}

// Return the protocol requirement with the specified name.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, DeclName name) {
auto lookup = proto->lookupDirect(name);
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
[](ValueDecl *v) {
return !isa<ProtocolDecl>(
v->getDeclContext()) ||
!v->isProtocolRequirement();
}),
lookup.end());
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
return lookup.front();
}

// Synthesize body for `_unpackTensorHandles(into:)`.
static void
deriveBodyTensorArrayProtocol_unpackTensorHandles(
Expand Down Expand Up @@ -349,12 +363,314 @@ static ValueDecl *deriveTensorArrayProtocol_tensorHandleCount(
return tensorHandleCountDecl;
}


/// Derive the body for the '_typeList' getter.
static void
deriveBodyTensorArrayProtocol_typeList(AbstractFunctionDecl *funcDecl) {
auto *parentDC = funcDecl->getParent();
auto *nominal = funcDecl->getDeclContext()->getSelfNominalTypeDecl();
auto &C = nominal->getASTContext();

auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
auto *typeListReq = getProtocolRequirement(tensorGroupProto, C.Id_typeList);

// Concatenate all member `_typeList` arrays.
Type arrayType = BoundGenericType::get(
C.getArrayDecl(), Type(),
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
auto *arrayTypeExpr = TypeExpr::createImplicit(arrayType, C);
auto plusOpLookup = C.getArrayDecl()->lookupDirect(C.getIdentifier("+"));
assert(plusOpLookup.size() == 1 && "Ambiguous 'Array.+' operator.");
ValueDecl *plusOpDecl = plusOpLookup.front();
auto plusOpDRE = new (C)
DeclRefExpr(plusOpDecl, DeclNameLoc(), /*Implicit*/ true);
auto plusOpExpr = new (C)
DotSyntaxCallExpr(plusOpDRE, SourceLoc(), arrayTypeExpr);
Expr *typeListExpr = ArrayExpr::create(C, SourceLoc(), {}, {}, SourceLoc());
for (auto member : nominal->getStoredProperties()) {
auto memberType =
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
auto *memberTypeListExpr = new (C)
MemberRefExpr(memberTypeExpr, SourceLoc(), typeListReq,
DeclNameLoc(), /*Implicit*/ true);
// Create expression `lhsArg + rhsArg`.
auto *plusOpArgs =
TupleExpr::create(C, SourceLoc(), {typeListExpr, memberTypeListExpr},
{}, {}, SourceLoc(), /*HasTrailingClosure*/ false,
/*Implicit*/ true);
typeListExpr = new (C) BinaryExpr(plusOpExpr, plusOpArgs,
/*Implicit*/ true);
}

// Return the resulting data types array.
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), typeListExpr);
auto *body = BraceStmt::create(C, SourceLoc(), {returnStmt}, SourceLoc(),
/*Implicit*/ true);
funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {body}, SourceLoc(),
/*Implicit*/ true));
}

/// Derive a '_typeList' implementation.
static ValueDecl *deriveTensorArrayProtocol_typeList(
DerivedConformance &derived) {
auto nominal = derived.Nominal;
auto &TC = derived.TC;
ASTContext &C = TC.Context;

auto parentDC = derived.getConformanceContext();
Type dataTypeArrayType = BoundGenericType::get(
C.getArrayDecl(), Type(),
{C.getTensorDataTypeDecl()->getDeclaredInterfaceType()});
auto returnType = parentDC->mapTypeIntoContext(dataTypeArrayType);

// Create `_typeList` property declaration.
VarDecl *typeListDecl;
PatternBindingDecl *patDecl;
std::tie(typeListDecl, patDecl) = derived.declareDerivedProperty(
C.Id_typeList, returnType, returnType, /*isStatic*/ false,
/*isFinal*/ false);

// Add `@inlinable` to the `_typeList` declaration.
if (nominal->getEffectiveAccess() > AccessLevel::Internal)
typeListDecl->getAttrs().add(new (C) InlinableAttr(/*implicit*/ true));

// Create `_typeList` getter.
auto *getterDecl = derived.declareDerivedPropertyGetter(
TC, typeListDecl, returnType);
getterDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_typeList);
typeListDecl->setAccessors(StorageImplInfo::getImmutableComputed(),
SourceLoc(), {getterDecl}, SourceLoc());
derived.addMembersToConformanceContext({getterDecl, typeListDecl, patDecl});

return typeListDecl;
}

// Synthesize body for `init(_owning:count:)`.
static void
deriveBodyTensorArrayProtocol_init(AbstractFunctionDecl *funcDecl) {
auto *parentDC = funcDecl->getParent();
auto *nominal = parentDC->getSelfNominalTypeDecl();
auto &C = nominal->getASTContext();

// Obtain the address type.
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
auto baseAddressType = BoundGenericType::get(
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
auto addressType = BoundGenericType::get(
C.getOptionalDecl(), Type(), {baseAddressType});
auto *addressTE = TypeExpr::createImplicit(addressType, C);

// Get references to `self` and parameter declarations.
auto *selfDecl = funcDecl->getImplicitSelfDecl();
auto *selfDRE = new (C)
DeclRefExpr(selfDecl, DeclNameLoc(), /*Implicit*/ true);
auto *paramDecl = funcDecl->getParameters()->get(0);
auto *paramDRE = new (C)
DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true);

// Create an `if var` statement for the current address.
VarDecl *currAddressDecl = new (C) VarDecl(
/*IsStatic*/ false, VarDecl::Specifier::Var, /*IsCaptureList*/ false,
SourceLoc(), C.getIdentifier("currentAddress"), funcDecl);
currAddressDecl->setImplicit();
currAddressDecl->setHasNonPatternBindingInit(true);
currAddressDecl->setInterfaceType(baseAddressType);
currAddressDecl->setValidationToChecked();

Pattern *currAddressPat = new (C)
NamedPattern(currAddressDecl, /*implicit*/ true);
currAddressPat = new (C)
VarPattern(SourceLoc(), /*isLet*/ false, currAddressPat,
/*implicit*/ true);
currAddressPat = new (C)
OptionalSomePattern(currAddressPat, currAddressPat->getEndLoc(),
/*implicit*/ true);
StmtConditionElement cond[] = {
StmtConditionElement(SourceLoc(), currAddressPat, /*Init*/ paramDRE)};

// Get the necessary protocol requirements.
auto *tensorGroupProto = C.getProtocol(KnownProtocolKind::TensorGroup);
auto *tensorArrayProto = C.getProtocol(
KnownProtocolKind::TensorArrayProtocol);
auto initName = DeclName(
C, DeclBaseName::createConstructor(), {C.getIdentifier("_owning")});
auto *initReq = getProtocolRequirement(tensorGroupProto, initName);
auto *tensorHandleCountReq = getProtocolRequirement(
tensorArrayProto, C.Id_tensorHandleCount);

Type intType = C.getIntDecl()->getDeclaredType();
TypeExpr *intTE = TypeExpr::createImplicit(intType, C);

// Iterate over members and call `self.t = T(_owning:)`.
llvm::SmallVector<ASTNode, 2> thenMemberExprs;
llvm::SmallVector<ASTNode, 2> elseMemberExprs;
for (auto member : nominal->getStoredProperties()) {
auto memberType = parentDC->mapTypeIntoContext(
member->getValueInterfaceType());
auto *memberTypeExpr = TypeExpr::createImplicit(memberType, C);
auto module = nominal->getModuleContext();
auto confRef = module->lookupConformance(
memberType, tensorGroupProto);
assert(confRef && "Member does not conform to `TensorGroup`");

// Get member type's constructor, e.g. `MemberType.init(_owning:)`.
// Use protocol requirement declaration for the method by default: this
// will be dynamically dispatched.
ValueDecl *memberInitDecl = initReq;
// If conformance reference is concrete, then use concrete witness
// declaration for the constructor.
if (confRef->isConcrete())
memberInitDecl = confRef->getConcrete()->getWitnessDecl(
initReq, C.getLazyResolver());
assert(memberInitDecl && "Member constructor declaration must exist");
auto memberInitDRE = new (C) DeclRefExpr(
memberInitDecl, DeclNameLoc(), /*implicit*/ true);
memberInitDRE->setFunctionRefKind(FunctionRefKind::SingleApply);

// Create reference to member constructor: `MemberType.init(_owning:)`.
auto *memberInitExpr = new (C) ConstructorRefCallExpr(
memberInitDRE, memberTypeExpr);

auto *addressDRE = new (C) DeclRefExpr(
currAddressDecl, DeclNameLoc(), /*implicit*/ true);
auto *loadExpr = new (C) LoadExpr(addressDRE, baseAddressType);

// Initialize the member using its TensorGroup constructor.
// Note that, initialization is dependent on the branch of the
// if-statement taken.
auto *thenInitExpr = new (C) InjectIntoOptionalExpr(loadExpr, addressType);
auto *thenInitCallExpr = CallExpr::createImplicit(
C, memberInitExpr, {thenInitExpr}, {C.getIdentifier("_owning")});

// Create a nil expression with type UnsafePointer<CTensorHandle>? for the
// `else` branch.
auto *nilDecl = C.getOptionalNoneDecl();
auto *nilDRE = new (C) DeclRefExpr(
nilDecl, DeclNameLoc(), /*implicit*/ true);
auto *elseInitExpr = new (C) DotSyntaxCallExpr(
nilDRE, SourceLoc(), addressTE);
auto *elseInitCallExpr = CallExpr::createImplicit(
C, memberInitExpr, {elseInitExpr}, {C.getIdentifier("_owning")});

// Assign the current member to the result of the initializer call.
auto *memberDRE = new (C) MemberRefExpr(
selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true);

auto *thenAssignMemberExpr = new (C) AssignExpr(
memberDRE, SourceLoc(), thenInitCallExpr, /*Implicit*/ true);
auto *elseAssignMemberExpr = new (C) AssignExpr(
memberDRE, SourceLoc(), elseInitCallExpr, /*Implicit*/ true);

thenMemberExprs.push_back(thenAssignMemberExpr);
elseMemberExprs.push_back(elseAssignMemberExpr);

// Advance the current address.
DeclName advancedName(C, C.getIdentifier("advanced"),
{C.getIdentifier("by")});
auto *advancedMethodExpr =
new (C) UnresolvedDotExpr(addressDRE, SourceLoc(),
advancedName, DeclNameLoc(),
/*Implicit*/ true);

// Obtain `MemberType._tensorHandleCount`.
auto *memberCountMRE = new (C) MemberRefExpr(
memberDRE, SourceLoc(), tensorHandleCountReq, DeclNameLoc(),
/*Implicit*/ true);

// Cast the tensor handle count to Int.
auto intInitName = DeclName(C, DeclBaseName::createConstructor(),
{Identifier()});
auto *intInitExpr =
new (C) UnresolvedDotExpr(intTE, SourceLoc(), intInitName,
DeclNameLoc(), /*Implicit*/ true);
auto *intInitCallExpr = CallExpr::createImplicit(
C, intInitExpr, {memberCountMRE}, {Identifier()});

// Assign the new address.
auto *assignAddrCallExpr = CallExpr::createImplicit(
C, advancedMethodExpr, {intInitCallExpr}, {C.getIdentifier("by")});
auto *assignAddrExpr = new (C) AssignExpr(addressDRE, SourceLoc(),
assignAddrCallExpr,
/*Implicit*/ true);

thenMemberExprs.push_back(assignAddrExpr);
}

auto *thenBody = BraceStmt::create(
C, SourceLoc(), C.AllocateCopy(thenMemberExprs), SourceLoc(),
/*implicit*/ true);

auto *elseBody = BraceStmt::create(
C, SourceLoc(), C.AllocateCopy(elseMemberExprs), SourceLoc(),
/*implicit*/ true);

auto *ifStmt = new (C)
IfStmt(LabeledStmtInfo(), /*IfLoc*/ SourceLoc(),
/*Cond*/ C.AllocateCopy(cond), /*Then*/ thenBody,
/*ElseLoc*/ SourceLoc(), /*Else*/ elseBody, /*implicit*/ true);

funcDecl->setBody(BraceStmt::create(C, SourceLoc(), {ifStmt}, SourceLoc(),
/*implicit*/ true));
}

// Synthesize the `init(_owning:count:)` function declaration.
static ValueDecl
*deriveTensorArrayProtocol_init(DerivedConformance &derived) {
auto &C = derived.TC.Context;
auto nominal = derived.Nominal;
auto parentDC = derived.getConformanceContext();

// Obtain the address type.
auto cTensorHandleType = C.getOpaquePointerDecl()->getDeclaredType();
Type baseAddressType = BoundGenericType::get(
C.getUnsafePointerDecl(), Type(), {cTensorHandleType});
Type addressType = BoundGenericType::get(
C.getOptionalDecl(), Type(), {baseAddressType});
Type intType = C.getIntDecl()->getDeclaredType();

auto *param1 = new (C) ParamDecl(
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
C.getIdentifier("_owning"), SourceLoc(), C.getIdentifier("tensorHandles"),
parentDC);
param1->setInterfaceType(addressType);
auto *param2 = new (C) ParamDecl(
VarDecl::Specifier::Default, SourceLoc(), SourceLoc(),
C.getIdentifier("count"), SourceLoc(), C.getIdentifier("count"), parentDC);
param2->setInterfaceType(intType);
ParameterList *params = ParameterList::create(C, {param1, param2});

DeclName name(C, DeclBaseName::createConstructor(), params);
auto *initDecl =
new (C) ConstructorDecl(name, SourceLoc(), OTK_None, SourceLoc(),
/*Throws*/ false, SourceLoc(), params,
/*GenericParams*/ nullptr, parentDC);
initDecl->setImplicit();
initDecl->setSynthesized();
initDecl->setBodySynthesizer(deriveBodyTensorArrayProtocol_init);

if (auto env = parentDC->getGenericEnvironmentOfContext())
initDecl->setGenericEnvironment(env);
initDecl->computeType(AnyFunctionType::ExtInfo().withThrows(false));
initDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
initDecl->setValidationToChecked();

derived.addMembersToConformanceContext({initDecl});
C.addSynthesizedDecl(initDecl);

return initDecl;
}

ValueDecl *DerivedConformance::deriveTensorArrayProtocol(
ValueDecl *requirement) {
if (requirement->getBaseName() == TC.Context.Id_unpackTensorHandles)
return deriveTensorArrayProtocol_unpackTensorHandles(*this);
if (requirement->getBaseName() == TC.Context.Id_tensorHandleCount)
return deriveTensorArrayProtocol_tensorHandleCount(*this);
if (requirement->getBaseName() == TC.Context.Id_typeList)
return deriveTensorArrayProtocol_typeList(*this);
if (requirement->getBaseName() == DeclBaseName::createConstructor())
return deriveTensorArrayProtocol_init(*this);
TC.diagnose(requirement->getLoc(),
diag::broken_tensor_array_protocol_requirement);
return nullptr;
Expand Down
12 changes: 12 additions & 0 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
// TensorArrayProtocol._tensorHandleCount
if (name.isSimpleName(ctx.Id_tensorHandleCount))
return getRequirement(KnownProtocolKind::TensorArrayProtocol);

// SWIFT_ENABLE_TENSORFLOW
// TensorArrayProtocol._typeList
if (name.isSimpleName(ctx.Id_typeList) && !requirement->isStatic())
return getRequirement(KnownProtocolKind::TensorArrayProtocol);

// SWIFT_ENABLE_TENSORFLOW
// TensorGroup._typeList
Expand Down Expand Up @@ -340,6 +345,13 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
if (argumentNames[0] == ctx.getIdentifier("_owning")) {
return getRequirement(KnownProtocolKind::TensorGroup);
}
} else if (argumentNames.size() == 2) {
// SWIFT_ENABLE_TENSORFLOW
// TensorArrayProtocol.init(_owning:count)
if (argumentNames[0] == ctx.getIdentifier("_owning") &&
argumentNames[1] == ctx.getIdentifier("count")) {
return getRequirement(KnownProtocolKind::TensorArrayProtocol);
}
}

return nullptr;
Expand Down
Loading