Skip to content

Commit 0e7438e

Browse files
authored
Fix stack overflow in containsTensorFlowValue for recursive data structures. (#21449)
* Break cycles in containsTensorFlowValue. * Use stack to keep track of parent decls. * Trying to add a unit test. * Updated the unit test with different scenarios * remove tab
1 parent 7484431 commit 0e7438e

File tree

4 files changed

+140
-13
lines changed

4 files changed

+140
-13
lines changed

include/swift/AST/TensorFlow.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "swift/Basic/LLVM.h"
2222
#include "llvm/ADT/DenseMap.h"
23+
#include "llvm/ADT/SetVector.h"
2324

2425
namespace swift {
2526
class CanType;
@@ -92,7 +93,12 @@ namespace tf {
9293
bool containsTensorFlowValue(Type ty, bool checkHigherOrderFunctions);
9394

9495
private:
95-
bool structContainsTensorFlowValue(StructDecl *decl);
96+
bool containsTensorFlowValueImpl(
97+
Type ty, bool checkHigherOrderFunctions,
98+
llvm::SetVector<NominalTypeDecl *> &parentDecls);
99+
100+
bool structContainsTensorFlowValue(
101+
StructDecl *decl, llvm::SetVector<NominalTypeDecl *> &parentDecls);
96102
};
97103

98104
/// This class provides a single source of truth for the set of types that are

lib/AST/TensorFlow.cpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ bool tf::flattenTensorFlowValueAggregate(Type ty,
220220
/// parameter of result contains any TensorFlow value type.
221221
bool TypeContainsTensorFlowValue::containsTensorFlowValue(
222222
Type ty, bool checkHigherOrderFunctions) {
223+
llvm::SetVector<NominalTypeDecl *> parentDecls;
224+
return containsTensorFlowValueImpl(ty, checkHigherOrderFunctions,
225+
parentDecls);
226+
}
227+
228+
bool TypeContainsTensorFlowValue::containsTensorFlowValueImpl(
229+
Type ty, bool checkHigherOrderFunctions,
230+
llvm::SetVector<NominalTypeDecl *> &parentDecls) {
223231
// If this type literally is a value type, then yep, we contain it. This is
224232
// the base case.
225233
if (isTensorFlowValue(ty))
@@ -229,40 +237,43 @@ bool TypeContainsTensorFlowValue::containsTensorFlowValue(
229237
// then the tuple itself does.
230238
if (auto *tuple = ty->getAs<TupleType>()) {
231239
for (auto &elt : tuple->getElements())
232-
if (containsTensorFlowValue(elt.getType(), checkHigherOrderFunctions))
240+
if (containsTensorFlowValueImpl(elt.getType(), checkHigherOrderFunctions,
241+
parentDecls))
233242
return true;
234243
return false;
235244
}
236245

237246
// Deabstraction scalarizes structs.
238247
if (auto *st = ty->getAs<StructType>())
239-
return structContainsTensorFlowValue(st->getDecl());
248+
return structContainsTensorFlowValue(st->getDecl(), parentDecls);
240249

241250
// Deabstractions binds specialized generic structs. Check if either the
242251
// struct itself or one of the generic arguments contains a tensor value.
243252
if (auto *bgst = ty->getAs<BoundGenericStructType>()) {
244253
// Check the generic arguments.
245254
for (auto arg : bgst->getGenericArgs())
246-
if (containsTensorFlowValue(arg, checkHigherOrderFunctions))
255+
if (containsTensorFlowValueImpl(arg, checkHigherOrderFunctions,
256+
parentDecls))
247257
return true;
248258

249-
return structContainsTensorFlowValue(bgst->getDecl());
259+
return structContainsTensorFlowValue(bgst->getDecl(), parentDecls);
250260
}
251261

252262
// Handle still-generic types that may contain a tensor value.
253263
if (auto *ugst = ty->getAs<UnboundGenericType>())
254264
if (auto *decl = dyn_cast<StructDecl>(ugst->getDecl()))
255-
return structContainsTensorFlowValue(decl);
265+
return structContainsTensorFlowValue(decl, parentDecls);
256266

257267
if (checkHigherOrderFunctions) {
258268
if (auto *fnType = ty->getAs<SILFunctionType>()) {
259269
for (auto &result : fnType->getResults())
260-
if (containsTensorFlowValue(result.getType(),
261-
checkHigherOrderFunctions))
270+
if (containsTensorFlowValueImpl(result.getType(),
271+
checkHigherOrderFunctions, parentDecls))
262272
return true;
263273

264274
for (auto &param : fnType->getParameters())
265-
if (containsTensorFlowValue(param.getType(), checkHigherOrderFunctions))
275+
if (containsTensorFlowValueImpl(param.getType(),
276+
checkHigherOrderFunctions, parentDecls))
266277
return true;
267278
}
268279
}
@@ -274,20 +285,27 @@ bool TypeContainsTensorFlowValue::containsTensorFlowValue(
274285

275286
/// Determine whether the given struct contains a TensorFlow value type, caching
276287
/// the result.
277-
bool TypeContainsTensorFlowValue::
278-
structContainsTensorFlowValue(StructDecl *decl) {
288+
bool TypeContainsTensorFlowValue::structContainsTensorFlowValue(
289+
StructDecl *decl, llvm::SetVector<NominalTypeDecl *> &parentDecls) {
290+
if (parentDecls.count(decl) > 0) {
291+
// We have a cycle, break it here.
292+
return false;
293+
}
279294
auto it = declContainsTensorFlowValue.find(decl);
280295
if (it != declContainsTensorFlowValue.end())
281296
return it->second;
282297

298+
parentDecls.insert(decl);
283299
bool hasTensorFlowValue = false;
284300
for (auto p : decl->getStoredProperties())
285-
if (containsTensorFlowValue(p->getType(),
286-
/*checkHigherOrderFunctions*/ false)) {
301+
if (containsTensorFlowValueImpl(p->getType(),
302+
/*checkHigherOrderFunctions*/ false,
303+
parentDecls)) {
287304
hasTensorFlowValue = true;
288305
break;
289306
}
290307

308+
parentDecls.pop_back();
291309
return declContainsTensorFlowValue[decl] = hasTensorFlowValue;
292310
}
293311

unittests/AST/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ add_swift_unittest(SwiftASTTests
44
# SWIFT_ENABLE_TENSORFLOW
55
SILAutoDiffIndices.cpp
66
SourceLocTests.cpp
7+
# SWIFT_ENABLE_TENSORFLOW
8+
TensorFlow.cpp
79
TestContext.cpp
810
TypeMatchTests.cpp
911
VersionRangeLattice.cpp

unittests/AST/TensorFlow.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//===--- SILAutoDiffIndices.cpp - Tests SILAutoDiffIndices ----------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See https://swift.org/LICENSE.txt for license information
9+
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
// SWIFT_ENABLE_TENSORFLOW
13+
14+
#include "TestContext.h"
15+
#include "swift/AST/AutoDiff.h"
16+
#include "swift/AST/Decl.h"
17+
#include "swift/AST/Type.h"
18+
#include "swift/AST/TensorFlow.h"
19+
#include "gtest/gtest.h"
20+
21+
using namespace swift;
22+
using namespace swift::unittest;
23+
24+
class TypeContainsTensorFlowValueTest : public ::testing::Test {
25+
protected:
26+
TestContext testContext;
27+
ASTContext& Ctx;
28+
Type floatType;
29+
BoundGenericClassType *tensorHandleType;
30+
tf::TypeContainsTensorFlowValue tctf;
31+
32+
TypeContainsTensorFlowValueTest()
33+
: Ctx(testContext.Ctx), floatType(createFloatType()),
34+
tensorHandleType(createTensorHandleType(floatType)) {}
35+
36+
Type createStructWithField(const char *name, Type fieldType) {
37+
auto *structDecl = testContext.makeNominal<StructDecl>(name);
38+
structDecl->computeType();
39+
auto *structType = StructType::get(structDecl, Type(), Ctx);
40+
auto *varDecl = new (Ctx) VarDecl(
41+
/*IsStatic*/ false, VarDecl::Specifier::Var,
42+
/*IsCaptureList*/ false, SourceLoc(), Ctx.getIdentifier("field"),
43+
structDecl);
44+
varDecl->setInterfaceType(fieldType);
45+
structDecl->addMember(varDecl);
46+
return structType;
47+
}
48+
49+
bool containsTensorFlowValue(Type t) {
50+
// TODO: Add tests with CheckHigherOrderFunctions set to true.
51+
return tctf.containsTensorFlowValue(t, /*CheckHigherOrderFunctions*/false);
52+
}
53+
54+
55+
private:
56+
Type createFloatType() {
57+
// Float type.
58+
auto *floatDecl = testContext.makeNominal<StructDecl>("Float");
59+
return StructType::get(floatDecl, Type(), Ctx);
60+
}
61+
62+
BoundGenericClassType *createTensorHandleType(Type genericType) {
63+
auto *tensorDecl = testContext.makeNominal<ClassDecl>("TensorHandle");
64+
65+
// Generic parameter list.
66+
auto *floatGenericParamDecl = new (Ctx)
67+
GenericTypeParamDecl(tensorDecl->getDeclContext(),
68+
Ctx.getIdentifier("Scalar"), SourceLoc(), 0, 0);
69+
auto *paramList = GenericParamList::create(
70+
Ctx, SourceLoc(), {floatGenericParamDecl}, SourceLoc());
71+
tensorDecl->setGenericParams(paramList);
72+
73+
return BoundGenericClassType::get(tensorDecl, Type(), {genericType});
74+
}
75+
};
76+
77+
TEST_F(TypeContainsTensorFlowValueTest, ClassifiesCorrectly) {
78+
EXPECT_TRUE(containsTensorFlowValue(tensorHandleType));
79+
EXPECT_TRUE(containsTensorFlowValue(
80+
createStructWithField("StructWithTensor", tensorHandleType)));
81+
EXPECT_FALSE(containsTensorFlowValue(
82+
createStructWithField("StructWithNoTensor", floatType)));
83+
}
84+
85+
TEST_F(TypeContainsTensorFlowValueTest, WorksForRecursiveTypes) {
86+
// Creates a recursive type for testing purposes. Note that this is not a
87+
// valid swift type, but should suffice for the purposes of this unit test.
88+
//
89+
auto *recursiveDecl = testContext.makeNominal<StructDecl>("RecursiveStruct");
90+
recursiveDecl->computeType();
91+
auto *recursiveType = StructType::get(recursiveDecl, Type(), Ctx);
92+
// Add a field of the recursiveType.
93+
// (This is not possible in swift, but ok for tests.)
94+
auto *varDecl = new (Ctx) VarDecl(
95+
/*IsStatic*/ false, VarDecl::Specifier::Var,
96+
/*IsCaptureList*/ false, SourceLoc(), Ctx.getIdentifier("someField"),
97+
recursiveDecl);
98+
varDecl->setInterfaceType(recursiveType);
99+
recursiveDecl->addMember(varDecl);
100+
EXPECT_FALSE(containsTensorFlowValue(recursiveType));
101+
}

0 commit comments

Comments
 (0)