Skip to content

Fix stack overflow containsTensorFlowValue for recursive data structures. #21449

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/swift/AST/TensorFlow.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "swift/Basic/LLVM.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"

namespace swift {
class CanType;
Expand Down Expand Up @@ -92,7 +93,12 @@ namespace tf {
bool containsTensorFlowValue(Type ty, bool checkHigherOrderFunctions);

private:
bool structContainsTensorFlowValue(StructDecl *decl);
bool containsTensorFlowValueImpl(
Type ty, bool checkHigherOrderFunctions,
llvm::SetVector<NominalTypeDecl *> &parentDecls);

bool structContainsTensorFlowValue(
StructDecl *decl, llvm::SetVector<NominalTypeDecl *> &parentDecls);
};

/// This class provides a single source of truth for the set of types that are
Expand Down
42 changes: 30 additions & 12 deletions lib/AST/TensorFlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ bool tf::flattenTensorFlowValueAggregate(Type ty,
/// parameter of result contains any TensorFlow value type.
bool TypeContainsTensorFlowValue::containsTensorFlowValue(
Type ty, bool checkHigherOrderFunctions) {
llvm::SetVector<NominalTypeDecl *> parentDecls;
return containsTensorFlowValueImpl(ty, checkHigherOrderFunctions,
parentDecls);
}

bool TypeContainsTensorFlowValue::containsTensorFlowValueImpl(
Type ty, bool checkHigherOrderFunctions,
llvm::SetVector<NominalTypeDecl *> &parentDecls) {
// If this type literally is a value type, then yep, we contain it. This is
// the base case.
if (isTensorFlowValue(ty))
Expand All @@ -229,40 +237,43 @@ bool TypeContainsTensorFlowValue::containsTensorFlowValue(
// then the tuple itself does.
if (auto *tuple = ty->getAs<TupleType>()) {
for (auto &elt : tuple->getElements())
if (containsTensorFlowValue(elt.getType(), checkHigherOrderFunctions))
if (containsTensorFlowValueImpl(elt.getType(), checkHigherOrderFunctions,
parentDecls))
return true;
return false;
}

// Deabstraction scalarizes structs.
if (auto *st = ty->getAs<StructType>())
return structContainsTensorFlowValue(st->getDecl());
return structContainsTensorFlowValue(st->getDecl(), parentDecls);

// Deabstractions binds specialized generic structs. Check if either the
// struct itself or one of the generic arguments contains a tensor value.
if (auto *bgst = ty->getAs<BoundGenericStructType>()) {
// Check the generic arguments.
for (auto arg : bgst->getGenericArgs())
if (containsTensorFlowValue(arg, checkHigherOrderFunctions))
if (containsTensorFlowValueImpl(arg, checkHigherOrderFunctions,
parentDecls))
return true;

return structContainsTensorFlowValue(bgst->getDecl());
return structContainsTensorFlowValue(bgst->getDecl(), parentDecls);
}

// Handle still-generic types that may contain a tensor value.
if (auto *ugst = ty->getAs<UnboundGenericType>())
if (auto *decl = dyn_cast<StructDecl>(ugst->getDecl()))
return structContainsTensorFlowValue(decl);
return structContainsTensorFlowValue(decl, parentDecls);

if (checkHigherOrderFunctions) {
if (auto *fnType = ty->getAs<SILFunctionType>()) {
for (auto &result : fnType->getResults())
if (containsTensorFlowValue(result.getType(),
checkHigherOrderFunctions))
if (containsTensorFlowValueImpl(result.getType(),
checkHigherOrderFunctions, parentDecls))
return true;

for (auto &param : fnType->getParameters())
if (containsTensorFlowValue(param.getType(), checkHigherOrderFunctions))
if (containsTensorFlowValueImpl(param.getType(),
checkHigherOrderFunctions, parentDecls))
return true;
}
}
Expand All @@ -274,20 +285,27 @@ bool TypeContainsTensorFlowValue::containsTensorFlowValue(

/// Determine whether the given struct contains a TensorFlow value type, caching
/// the result.
bool TypeContainsTensorFlowValue::
structContainsTensorFlowValue(StructDecl *decl) {
bool TypeContainsTensorFlowValue::structContainsTensorFlowValue(
StructDecl *decl, llvm::SetVector<NominalTypeDecl *> &parentDecls) {
if (parentDecls.count(decl) > 0) {
// We have a cycle, break it here.
return false;
}
auto it = declContainsTensorFlowValue.find(decl);
if (it != declContainsTensorFlowValue.end())
return it->second;

parentDecls.insert(decl);
bool hasTensorFlowValue = false;
for (auto p : decl->getStoredProperties())
if (containsTensorFlowValue(p->getType(),
/*checkHigherOrderFunctions*/ false)) {
if (containsTensorFlowValueImpl(p->getType(),
/*checkHigherOrderFunctions*/ false,
parentDecls)) {
hasTensorFlowValue = true;
break;
}

parentDecls.pop_back();
return declContainsTensorFlowValue[decl] = hasTensorFlowValue;
}

Expand Down
2 changes: 2 additions & 0 deletions unittests/AST/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ add_swift_unittest(SwiftASTTests
# SWIFT_ENABLE_TENSORFLOW
SILAutoDiffIndices.cpp
SourceLocTests.cpp
# SWIFT_ENABLE_TENSORFLOW
TensorFlow.cpp
TestContext.cpp
TypeMatchTests.cpp
VersionRangeLattice.cpp
Expand Down
101 changes: 101 additions & 0 deletions unittests/AST/TensorFlow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===--- SILAutoDiffIndices.cpp - Tests SILAutoDiffIndices ----------------===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
// SWIFT_ENABLE_TENSORFLOW

#include "TestContext.h"
#include "swift/AST/AutoDiff.h"
#include "swift/AST/Decl.h"
#include "swift/AST/Type.h"
#include "swift/AST/TensorFlow.h"
#include "gtest/gtest.h"

using namespace swift;
using namespace swift::unittest;

class TypeContainsTensorFlowValueTest : public ::testing::Test {
protected:
TestContext testContext;
ASTContext& Ctx;
Type floatType;
BoundGenericClassType *tensorHandleType;
tf::TypeContainsTensorFlowValue tctf;

TypeContainsTensorFlowValueTest()
: Ctx(testContext.Ctx), floatType(createFloatType()),
tensorHandleType(createTensorHandleType(floatType)) {}

Type createStructWithField(const char *name, Type fieldType) {
auto *structDecl = testContext.makeNominal<StructDecl>(name);
structDecl->computeType();
auto *structType = StructType::get(structDecl, Type(), Ctx);
auto *varDecl = new (Ctx) VarDecl(
/*IsStatic*/ false, VarDecl::Specifier::Var,
/*IsCaptureList*/ false, SourceLoc(), Ctx.getIdentifier("field"),
structDecl);
varDecl->setInterfaceType(fieldType);
structDecl->addMember(varDecl);
return structType;
}

bool containsTensorFlowValue(Type t) {
// TODO: Add tests with CheckHigherOrderFunctions set to true.
return tctf.containsTensorFlowValue(t, /*CheckHigherOrderFunctions*/false);
}


private:
Type createFloatType() {
// Float type.
auto *floatDecl = testContext.makeNominal<StructDecl>("Float");
return StructType::get(floatDecl, Type(), Ctx);
}

BoundGenericClassType *createTensorHandleType(Type genericType) {
auto *tensorDecl = testContext.makeNominal<ClassDecl>("TensorHandle");

// Generic parameter list.
auto *floatGenericParamDecl = new (Ctx)
GenericTypeParamDecl(tensorDecl->getDeclContext(),
Ctx.getIdentifier("Scalar"), SourceLoc(), 0, 0);
auto *paramList = GenericParamList::create(
Ctx, SourceLoc(), {floatGenericParamDecl}, SourceLoc());
tensorDecl->setGenericParams(paramList);

return BoundGenericClassType::get(tensorDecl, Type(), {genericType});
}
};

TEST_F(TypeContainsTensorFlowValueTest, ClassifiesCorrectly) {
EXPECT_TRUE(containsTensorFlowValue(tensorHandleType));
EXPECT_TRUE(containsTensorFlowValue(
createStructWithField("StructWithTensor", tensorHandleType)));
EXPECT_FALSE(containsTensorFlowValue(
createStructWithField("StructWithNoTensor", floatType)));
}

TEST_F(TypeContainsTensorFlowValueTest, WorksForRecursiveTypes) {
// Creates a recursive type for testing purposes. Note that this is not a
// valid swift type, but should suffice for the purposes of this unit test.
//
auto *recursiveDecl = testContext.makeNominal<StructDecl>("RecursiveStruct");
recursiveDecl->computeType();
auto *recursiveType = StructType::get(recursiveDecl, Type(), Ctx);
// Add a field of the recursiveType.
// (This is not possible in swift, but ok for tests.)
auto *varDecl = new (Ctx) VarDecl(
/*IsStatic*/ false, VarDecl::Specifier::Var,
/*IsCaptureList*/ false, SourceLoc(), Ctx.getIdentifier("someField"),
recursiveDecl);
varDecl->setInterfaceType(recursiveType);
recursiveDecl->addMember(varDecl);
EXPECT_FALSE(containsTensorFlowValue(recursiveType));
}