Skip to content

Commit 457df8f

Browse files
author
Eugene Burmako
committed
Fix merge fallout
1 parent ff48eed commit 457df8f

File tree

10 files changed

+1
-333
lines changed

10 files changed

+1
-333
lines changed

include/swift/AST/ASTContext.h

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -291,18 +291,6 @@ class ASTContext final {
291291
/// across invocations of both the parser and the type-checker.
292292
unsigned NextAutoClosureDiscriminator = 0;
293293

294-
// SWIFT_ENABLE_TENSORFLOW
295-
296-
/// Cache of `@differentiable` attributes keyed by parameter indices. Used to
297-
/// diagnose duplicate `@differentiable` attributes for the same key.
298-
// NOTE(TF-680): relaxing the uniqueness condition to use derivative generic
299-
// signature as a key is possible. It requires derivative generic signature
300-
// mangling to avoid name collisions for SIL derivative functions with the
301-
// same parameter indices but different derivative generic signatures.
302-
llvm::DenseMap<std::pair<Decl *, IndexSubset *>, DifferentiableAttr *>
303-
DifferentiableAttrs;
304-
// SWIFT_ENABLE_TENSORFLOW END
305-
306294
/// Cached mapping from types to their associated tangent spaces.
307295
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;
308296

lib/AST/ASTScopeCreation.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,12 +1725,6 @@ NullablePtr<AbstractStorageDecl>
17251725
SpecializeAttributeScope::getEnclosingAbstractStorageDecl() const {
17261726
return getParent().get()->getEnclosingAbstractStorageDecl();
17271727
}
1728-
// SWIFT_ENABLE_TENSORFLOW
1729-
NullablePtr<AbstractStorageDecl>
1730-
DifferentiableAttributeScope::getEnclosingAbstractStorageDecl() const {
1731-
return getParent().get()->getEnclosingAbstractStorageDecl();
1732-
}
1733-
// SWIFT_ENABLE_TENSORFLOW END
17341728
NullablePtr<AbstractStorageDecl>
17351729
DifferentiableAttributeScope::getEnclosingAbstractStorageDecl() const {
17361730
return getParent().get()->getEnclosingAbstractStorageDecl();

lib/AST/Attr.cpp

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1522,33 +1522,6 @@ DifferentiableAttr::create(AbstractFunctionDecl *original, bool implicit,
15221522
std::move(vjp), derivativeGenSig);
15231523
}
15241524

1525-
// SWIFT_ENABLE_TENSORFLOW
1526-
bool DifferentiableAttr::hasComputedParameterIndices() const {
1527-
return ParameterIndicesAndBit.getInt();
1528-
}
1529-
1530-
IndexSubset *DifferentiableAttr::getParameterIndices() const {
1531-
assert(getOriginalDeclaration() &&
1532-
"Original declaration must have been resolved");
1533-
auto &ctx = getOriginalDeclaration()->getASTContext();
1534-
return evaluateOrDefault(
1535-
ctx.evaluator,
1536-
DifferentiableAttributeParameterIndicesRequest{
1537-
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1538-
nullptr);
1539-
}
1540-
1541-
void DifferentiableAttr::setParameterIndices(IndexSubset *paramIndices) {
1542-
assert(getOriginalDeclaration() &&
1543-
"Original declaration must have been resolved");
1544-
auto &ctx = getOriginalDeclaration()->getASTContext();
1545-
ctx.evaluator.cacheOutput(
1546-
DifferentiableAttributeParameterIndicesRequest{
1547-
const_cast<DifferentiableAttr *>(this), getOriginalDeclaration()},
1548-
std::move(paramIndices));
1549-
}
1550-
// SWIFT_ENABLE_TENSORFLOW END
1551-
15521525
void DifferentiableAttr::setOriginalDeclaration(Decl *originalDeclaration) {
15531526
assert(originalDeclaration && "Original declaration must be non-null");
15541527
assert(!OriginalDeclaration &&

lib/AST/Type.cpp

Lines changed: 0 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -4806,145 +4806,6 @@ Type TypeBase::openAnyExistentialType(OpenedArchetypeType *&opened) {
48064806
return opened;
48074807
}
48084808

4809-
// SWIFT_ENABLE_TENSORFLOW
4810-
// Makes a function with the same generic signature and ExtInfo as `copy`, but
4811-
// with `params` parameters and `retTy` return type.
4812-
static AnyFunctionType *
4813-
makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
4814-
Type retTy, GenericSignature genericSignature) {
4815-
if (!genericSignature)
4816-
if (auto *genericFunctionType = copy->getAs<GenericFunctionType>())
4817-
genericSignature = genericFunctionType->getGenericSignature();
4818-
if (genericSignature)
4819-
return GenericFunctionType::get(genericSignature, params, retTy,
4820-
copy->getExtInfo());
4821-
return FunctionType::get(params, retTy, copy->getExtInfo());
4822-
}
4823-
4824-
AnyFunctionType *AnyFunctionType::getAutoDiffDerivativeFunctionType(
4825-
IndexSubset *indices, unsigned resultIndex,
4826-
AutoDiffDerivativeFunctionKind kind, LookupConformanceFn lookupConformance,
4827-
GenericSignature whereClauseGenSig, bool makeSelfParamFirst) {
4828-
// JVP: (T...) -> ((R...),
4829-
// (T.TangentVector...) -> (R.TangentVector...))
4830-
// VJP: (T...) -> ((R...),
4831-
// (R.TangentVector...) -> (T.TangentVector...))
4832-
//
4833-
// Note that both can be written as "(T...) -> ((R...), Closure)", so we build
4834-
// "Closure" and then use common code to wrap "Closure" in the outer function
4835-
// type.
4836-
4837-
assert(!indices->isEmpty() && "there must be at least one wrt index");
4838-
4839-
auto &ctx = getASTContext();
4840-
4841-
// Get differentiability parameter types.
4842-
SmallVector<Type, 8> diffParamTypes;
4843-
autodiff::getSubsetParameterTypes(indices, this, diffParamTypes,
4844-
/*reverseCurryLevels*/ !makeSelfParamFirst);
4845-
4846-
// Unwrap curry levels. At most, two parameter lists are necessary, for
4847-
// curried method types with a `(Self)` parameter list.
4848-
SmallVector<AnyFunctionType *, 2> curryLevels;
4849-
auto *currentLevel = castTo<AnyFunctionType>();
4850-
for (unsigned i : range(2)) {
4851-
(void)i;
4852-
if (currentLevel == nullptr)
4853-
break;
4854-
curryLevels.push_back(currentLevel);
4855-
currentLevel = currentLevel->getResult()->getAs<AnyFunctionType>();
4856-
}
4857-
4858-
Type originalResult = curryLevels.back()->getResult();
4859-
4860-
// Build the closure type, which is different depending on whether this is a
4861-
// JVP or VJP.
4862-
Type closure;
4863-
switch (kind) {
4864-
case AutoDiffDerivativeFunctionKind::JVP: {
4865-
// closure is the JVP "differential":
4866-
// (T.TangentVector...) -> (R.TangentVector...)
4867-
SmallVector<AnyFunctionType::Param, 8> differentialParams;
4868-
for (auto diffParamType : diffParamTypes)
4869-
differentialParams.push_back(AnyFunctionType::Param(
4870-
diffParamType->getAutoDiffTangentSpace(lookupConformance)
4871-
->getType()));
4872-
4873-
SmallVector<TupleTypeElt, 8> differentialResults;
4874-
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
4875-
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
4876-
differentialResults.push_back(resultTupleEltType
4877-
->getAutoDiffTangentSpace(lookupConformance)->getType());
4878-
} else {
4879-
assert(resultIndex == 0 && "resultIndex out of bounds");
4880-
differentialResults.push_back(
4881-
originalResult->getAutoDiffTangentSpace(lookupConformance)
4882-
->getType());
4883-
}
4884-
Type differentialResult =
4885-
differentialResults.size() > 1
4886-
? TupleType::get(differentialResults, ctx)
4887-
: differentialResults[0].getType();
4888-
4889-
closure = FunctionType::get(differentialParams, differentialResult);
4890-
break;
4891-
}
4892-
case AutoDiffDerivativeFunctionKind::VJP: {
4893-
// closure is the VJP "pullback":
4894-
// (R.TangentVector...) -> (T.TangentVector...)
4895-
SmallVector<AnyFunctionType::Param, 8> pullbackParams;
4896-
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
4897-
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
4898-
pullbackParams.push_back(
4899-
AnyFunctionType::Param(resultTupleEltType
4900-
->getAutoDiffTangentSpace(lookupConformance)
4901-
->getType()));
4902-
} else {
4903-
assert(resultIndex == 0 && "resultIndex out of bounds");
4904-
pullbackParams.push_back(
4905-
AnyFunctionType::Param(originalResult
4906-
->getAutoDiffTangentSpace(lookupConformance)
4907-
->getType()));
4908-
}
4909-
4910-
SmallVector<TupleTypeElt, 8> pullbackResults;
4911-
for (auto diffParamType : diffParamTypes)
4912-
pullbackResults.push_back(
4913-
diffParamType->getAutoDiffTangentSpace(lookupConformance)
4914-
->getType());
4915-
Type pullbackResult = pullbackResults.size() > 1
4916-
? TupleType::get(pullbackResults, ctx)
4917-
: pullbackResults[0].getType();
4918-
4919-
closure = FunctionType::get(pullbackParams, pullbackResult);
4920-
break;
4921-
}
4922-
}
4923-
assert(closure && "should have built a closure");
4924-
4925-
// Build "(T...) -> ((R...), Closure)".
4926-
SmallVector<TupleTypeElt, 2> retElts;
4927-
retElts.push_back(originalResult);
4928-
retElts.push_back(closure);
4929-
auto retTy = TupleType::get(retElts, ctx);
4930-
auto *derivativeFunction = makeFunctionType(
4931-
curryLevels.back(), curryLevels.back()->getParams(), retTy,
4932-
curryLevels.size() == 1 ? whereClauseGenSig : nullptr);
4933-
4934-
// Wrap the derivative function type in additional curry levels.
4935-
auto curryLevelsWithoutLast =
4936-
ArrayRef<AnyFunctionType *>(curryLevels).drop_back(1);
4937-
for (auto pair : enumerate(llvm::reverse(curryLevelsWithoutLast))) {
4938-
unsigned i = pair.index();
4939-
AnyFunctionType *curryLevel = pair.value();
4940-
derivativeFunction = makeFunctionType(
4941-
curryLevel, curryLevel->getParams(), derivativeFunction,
4942-
i == curryLevelsWithoutLast.size() - 1 ? whereClauseGenSig : nullptr);
4943-
}
4944-
4945-
return derivativeFunction;
4946-
}
4947-
49484809
bool TypeBase::hasOpaqueArchetypePropertiesOrCases() {
49494810
if (auto *structDecl = getStructOrBoundGenericStruct()) {
49504811
for (auto *field : structDecl->getStoredProperties()) {

lib/AST/TypeCheckRequests.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ void swift::simple_display(llvm::raw_ostream &out,
10621062
Optional<IndexSubset *>
10631063
DifferentiableAttributeParameterIndicesRequest::getCachedResult() const {
10641064
auto *attr = std::get<0>(getStorage());
1065-
if (attr->hasComputedParameterIndices())
1065+
if (attr->hasBeenTypeChecked())
10661066
return attr->ParameterIndicesAndBit.getPointer();
10671067
return None;
10681068
}

lib/Sema/TypeCheckAttr.cpp

Lines changed: 0 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -121,10 +121,6 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
121121
IGNORED_ATTR(ProjectedValueProperty)
122122
IGNORED_ATTR(ReferenceOwnership)
123123
IGNORED_ATTR(OriginallyDefinedIn)
124-
// TODO(TF-830): Upstream `@transpose` attribute type-checking from tensorflow
125-
// branch.
126-
IGNORED_ATTR(Transpose)
127-
128124
// SWIFT_ENABLE_TENSORFLOW
129125
// TODO(TF-715): Allow @quoted on more decls.
130126
IGNORED_ATTR(Quoted)
@@ -265,7 +261,6 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
265261
void visitDifferentiableAttr(DifferentiableAttr *attr);
266262
void visitDerivativeAttr(DerivativeAttr *attr);
267263
// SWIFT_ENABLE_TENSORFLOW
268-
void visitDifferentiableAttr(DifferentiableAttr *attr);
269264
void visitTransposeAttr(TransposeAttr *attr);
270265
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
271266
void visitDifferentiatingAttr(DerivativeAttr *attr);
@@ -4450,121 +4445,6 @@ static bool tangentVectorEqualsSelf(Type type, DeclContext *DC) {
44504445
};
44514446

44524447
// SWIFT_ENABLE_TENSORFLOW
4453-
// Finds a derivative function declaration using the given function specifier,
4454-
// original function declaration, expected type, and "is valid" predicate. If no
4455-
// valid derivative function is found, emits diagnostics and returns false.
4456-
static FuncDecl *findAutoDiffDerivativeFunction(
4457-
DeclNameRefWithLoc specifier, AbstractFunctionDecl *original, Type expectedTy,
4458-
std::function<bool(AbstractFunctionDecl *)> isValid) {
4459-
auto &ctx = original->getASTContext();
4460-
auto &diags = ctx.Diags;
4461-
auto noneValidDiagnostic = [&]() {
4462-
diags.diagnose(specifier.Loc,
4463-
diag::differentiable_attr_overload_not_found, specifier.Name,
4464-
expectedTy);
4465-
};
4466-
auto ambiguousDiagnostic = [&]() {
4467-
diags.diagnose(specifier.Loc, diag::attr_ambiguous_reference_to_decl,
4468-
specifier.Name, "differentiable");
4469-
};
4470-
auto notFunctionDiagnostic = [&]() {
4471-
diags.diagnose(specifier.Loc,
4472-
diag::differentiable_attr_derivative_not_function,
4473-
specifier.Name);
4474-
};
4475-
std::function<void()> invalidTypeContextDiagnostic = [&]() {
4476-
diags.diagnose(specifier.Loc,
4477-
diag::differentiable_attr_function_not_same_type_context,
4478-
specifier.Name);
4479-
};
4480-
4481-
// Returns true if the original function and derivative function candidate are
4482-
// defined in compatible type contexts. If the original function and the
4483-
// derivative function have different parents, or if they both have no type
4484-
// context and are in different modules, return false.
4485-
std::function<bool(AbstractFunctionDecl *)> hasValidTypeContext =
4486-
[&](AbstractFunctionDecl *func) {
4487-
// Check if both functions are top-level.
4488-
if (!original->getInnermostTypeContext() &&
4489-
!func->getInnermostTypeContext() &&
4490-
original->getParentModule() == func->getParentModule())
4491-
return true;
4492-
// Check if both functions are defined in the same type context.
4493-
if (auto typeCtx1 = original->getInnermostTypeContext())
4494-
if (auto typeCtx2 = func->getInnermostTypeContext())
4495-
return typeCtx1->getSelfNominalTypeDecl() ==
4496-
typeCtx2->getSelfNominalTypeDecl();
4497-
return original->getParent() == func->getParent();
4498-
};
4499-
4500-
auto isABIPublic = [&](AbstractFunctionDecl *func) {
4501-
return func->getFormalAccess() >= AccessLevel::Public ||
4502-
func->getAttrs().hasAttribute<InlinableAttr>() ||
4503-
func->getAttrs().hasAttribute<UsableFromInlineAttr>();
4504-
};
4505-
4506-
// If the original function is exported (i.e. it is public or
4507-
// `@usableFromInline`), then the derivative functions must also be exported.
4508-
// Returns true on error.
4509-
auto checkAccessControl = [&](AbstractFunctionDecl *func) {
4510-
if (!isABIPublic(original))
4511-
return false;
4512-
if (isABIPublic(func))
4513-
return false;
4514-
diags.diagnose(specifier.Loc, diag::differentiable_attr_invalid_access,
4515-
specifier.Name, original->getFullName());
4516-
return true;
4517-
};
4518-
4519-
auto originalTypeCtx = original->getInnermostTypeContext();
4520-
if (!originalTypeCtx) originalTypeCtx = original->getParent();
4521-
assert(originalTypeCtx);
4522-
4523-
// Set lookup options.
4524-
auto lookupOptions = defaultMemberLookupOptions
4525-
| NameLookupFlags::IgnoreAccessControl;
4526-
4527-
auto *candidate = findAbstractFunctionDecl(
4528-
specifier.Name, specifier.Loc.getBaseNameLoc(), /*baseType*/ Type(),
4529-
originalTypeCtx, isValid, noneValidDiagnostic, ambiguousDiagnostic,
4530-
notFunctionDiagnostic, lookupOptions, hasValidTypeContext,
4531-
invalidTypeContextDiagnostic);
4532-
if (!candidate)
4533-
return nullptr;
4534-
// Reject non-`func` registered derivatives. JVPs and VJPs must be `func`
4535-
// declarations.
4536-
if (isa<AccessorDecl>(candidate)) {
4537-
diags.diagnose(specifier.Loc,
4538-
diag::differentiable_attr_derivative_not_function,
4539-
specifier.Name);
4540-
return nullptr;
4541-
}
4542-
if (checkAccessControl(candidate))
4543-
return nullptr;
4544-
// Derivatives of class members must be final.
4545-
if (original->getDeclContext()->getSelfClassDecl() &&
4546-
!candidate->isFinal()) {
4547-
diags.diagnose(specifier.Loc,
4548-
diag::differentiable_attr_class_derivative_not_final);
4549-
return nullptr;
4550-
}
4551-
assert(isa<FuncDecl>(candidate));
4552-
auto *funcDecl = cast<FuncDecl>(candidate);
4553-
return funcDecl;
4554-
}
4555-
4556-
// SWIFT_ENABLE_TENSORFLOW
4557-
void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
4558-
// Call `getParameterIndices` to trigger a
4559-
// `DifferentiableAttributeParameterIndicesRequest`, which currently performs
4560-
// full `@differentiable` type-checking.
4561-
// TODO: Consider creating separate requests for the following functionality:
4562-
// - `DifferentiableAttr::getJVPFunction`
4563-
// - `DifferentiableAttr::getVJPFunction`
4564-
// - `DifferentiableAttr::getDerivativeGenericSignature`
4565-
(void)attr->getParameterIndices();
4566-
}
4567-
45684448
llvm::Expected<IndexSubset *>
45694449
DifferentiableAttributeParameterIndicesRequest::evaluate(
45704450
Evaluator &evaluator, DifferentiableAttr *attr, Decl *D) const {

lib/Sema/TypeCheckProtocol.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -358,15 +358,6 @@ struct RequirementMatch {
358358
assert(!hasWitnessType() && "Should have witness type");
359359
}
360360

361-
// SWIFT_ENABLE_TENSORFLOW
362-
RequirementMatch(ValueDecl *witness, MatchKind kind,
363-
const DeclAttribute *attr)
364-
: Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr),
365-
ReqEnv(None) {
366-
assert(!hasWitnessType() && "Should have witness type");
367-
assert(UnmetAttribute);
368-
}
369-
370361
RequirementMatch(ValueDecl *witness, MatchKind kind,
371362
const DeclAttribute *attr)
372363
: Witness(witness), Kind(kind), WitnessType(), UnmetAttribute(attr),

lib/Serialization/Deserialization.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4350,13 +4350,6 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
43504350
break;
43514351
}
43524352

4353-
case decls_block::ImplicitlySynthesizesNestedRequirement_DECL_ATTR: {
4354-
serialization::decls_block::ImplicitlySynthesizesNestedRequirementDeclAttrLayout
4355-
::readRecord(scratch);
4356-
Attr = new (ctx) ImplicitlySynthesizesNestedRequirementAttr(blobData, {}, {});
4357-
break;
4358-
}
4359-
43604353
#define SIMPLE_DECL_ATTR(NAME, CLASS, ...) \
43614354
case decls_block::CLASS##_DECL_ATTR: { \
43624355
bool isImplicit; \

0 commit comments

Comments
 (0)