Skip to content

Commit 3d155d2

Browse files
soulitzerguilhermeleobas
authored andcommitted
Move SingletonSymNodeImpl from c10 to aten (pytorch#114895)
Pull Request resolved: pytorch#114895 Approved by: https://github.com/jbschlosser
1 parent 0b3f2db commit 3d155d2

File tree

7 files changed

+111
-100
lines changed

7 files changed

+111
-100
lines changed

c10/core/SingletonSymNodeImpl.cpp renamed to aten/src/ATen/core/SingletonSymNodeImpl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include <c10/core/SingletonSymNodeImpl.h>
1+
#include <ATen/core/SingletonSymNodeImpl.h>
22
#include <c10/core/SymNodeImpl.h>
33
#include <c10/util/Exception.h>
44

c10/core/SingletonSymNodeImpl.h renamed to aten/src/ATen/core/SingletonSymNodeImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace c10 {
2828
// During tracing the strides of the outputs need to be a function of the size
2929
// and strides of the inputs so it is important that SingletonSymNode itself is
3030
// able to express this.
31-
class C10_API SingletonSymNodeImpl : public SymNodeImpl {
31+
class TORCH_API SingletonSymNodeImpl : public SymNodeImpl {
3232
public:
3333
// CAUTION: you should probably not be constructing these directly; please
3434
// the higher-level API in python instead (TODO: actually introduce that).

build_variables.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,7 @@ aten_cpu_source_non_codegen_list = [
10211021
"aten/src/ATen/core/operator_name.cpp",
10221022
"aten/src/ATen/core/TorchDispatchUtils.cpp",
10231023
"aten/src/ATen/core/register_symbols.cpp",
1024+
"aten/src/ATen/core/SingletonSymNodeImpl.cpp",
10241025
"aten/src/ATen/core/class_type.cpp",
10251026
"aten/src/ATen/core/type.cpp",
10261027
"aten/src/ATen/core/type_factory.cpp",

c10/test/core/SymInt_test.cpp

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <gtest/gtest.h>
22

3-
#include <c10/core/SingletonSymNodeImpl.h>
43
#include <c10/core/SymInt.h>
54
#include <c10/core/SymNodeImpl.h>
65

@@ -23,100 +22,4 @@ TEST(SymIntTest, CheckRange) {
2322
EXPECT_FALSE(SymInt::check_range(INT64_MIN));
2423
}
2524

26-
TEST(SymIntTest, SingletonSymNode) {
27-
auto a = c10::SymInt(
28-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
29-
auto b = c10::SymInt(
30-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
31-
auto c = c10::SymInt(
32-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
33-
auto d = c10::SymInt(3);
34-
35-
ASSERT_TRUE(a == a);
36-
ASSERT_TRUE(a == b);
37-
ASSERT_FALSE(a != a);
38-
ASSERT_FALSE(a != b);
39-
ASSERT_FALSE(a == c);
40-
ASSERT_TRUE(a != c);
41-
42-
ASSERT_FALSE(a == d);
43-
ASSERT_TRUE(a != d);
44-
ASSERT_FALSE(d == a);
45-
ASSERT_TRUE(d != a);
46-
47-
// ge
48-
ASSERT_TRUE(a >= a);
49-
ASSERT_TRUE(a >= b);
50-
ASSERT_TRUE(b >= a);
51-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
52-
EXPECT_THROW((void)(a >= c), c10::Error);
53-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
54-
EXPECT_THROW((void)(c >= a), c10::Error);
55-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
56-
EXPECT_THROW((void)(c >= 3), c10::Error);
57-
ASSERT_TRUE(c >= 2);
58-
ASSERT_TRUE(c >= 1);
59-
ASSERT_FALSE(1 >= c);
60-
61-
// lt
62-
ASSERT_FALSE(a < a);
63-
ASSERT_FALSE(a < b);
64-
ASSERT_FALSE(b < a);
65-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
66-
EXPECT_THROW((void)(a < c), c10::Error);
67-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
68-
EXPECT_THROW((void)(c < a), c10::Error);
69-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
70-
EXPECT_THROW((void)(3 < a), c10::Error);
71-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
72-
EXPECT_THROW((void)(2 < a), c10::Error);
73-
ASSERT_TRUE(1 < a);
74-
75-
// le
76-
ASSERT_TRUE(a <= a);
77-
ASSERT_TRUE(b <= a);
78-
ASSERT_TRUE(a <= b);
79-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
80-
EXPECT_THROW((void)(a <= c), c10::Error);
81-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
82-
EXPECT_THROW((void)(c <= a), c10::Error);
83-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
84-
EXPECT_THROW((void)(3 <= c), c10::Error);
85-
ASSERT_TRUE(2 <= c);
86-
ASSERT_TRUE(1 <= c);
87-
ASSERT_FALSE(c <= 1);
88-
89-
// gt
90-
ASSERT_FALSE(a > a);
91-
ASSERT_FALSE(b > a);
92-
ASSERT_FALSE(a > b);
93-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
94-
EXPECT_THROW((void)(a > c), c10::Error);
95-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
96-
EXPECT_THROW((void)(c > a), c10::Error);
97-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
98-
EXPECT_THROW((void)(a > 3), c10::Error);
99-
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
100-
EXPECT_THROW((void)(a > 2), c10::Error);
101-
ASSERT_TRUE(a > 1);
102-
}
103-
104-
TEST(SymIntTest, SingletonSymNodeWithFactor) {
105-
auto a = c10::SymInt(
106-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
107-
auto b = c10::SymInt(
108-
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
109-
// eq
110-
ASSERT_FALSE(a == b);
111-
ASSERT_FALSE(a >= b);
112-
ASSERT_TRUE(b >= a);
113-
ASSERT_TRUE(a <= b);
114-
ASSERT_FALSE(b <= a);
115-
// ne
116-
ASSERT_TRUE(a != b);
117-
// mul
118-
ASSERT_TRUE(a * 2 == b);
119-
ASSERT_TRUE(a * 3 >= b);
120-
ASSERT_TRUE(a * 2 == 2 * a);
121-
}
12225
#endif

test/cpp/api/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ set(TORCH_API_TEST_SOURCES
4141
${TORCH_API_TEST_DIR}/inference_mode.cpp
4242
${TORCH_API_TEST_DIR}/grad_mode.cpp
4343
${TORCH_API_TEST_DIR}/operations.cpp
44+
${TORCH_API_TEST_DIR}/singleton_int.cpp
4445
)
4546
if(USE_CUDA OR USE_ROCM)
4647
list(APPEND TORCH_API_TEST_SOURCES ${TORCH_API_TEST_DIR}/parallel.cpp)

test/cpp/api/singleton_int.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <ATen/core/SingletonSymNodeImpl.h>
4+
#include <c10/core/SymInt.h>
5+
#include <c10/core/SymNodeImpl.h>
6+
#include <torch/torch.h>
7+
8+
#include <test/cpp/api/support.h>
9+
10+
TEST(SingletonIntTest, Comparisons) {
11+
auto a = c10::SymInt(
12+
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
13+
auto b = c10::SymInt(
14+
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 1)));
15+
auto c = c10::SymInt(
16+
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(2, 1)));
17+
auto d = c10::SymInt(3);
18+
19+
ASSERT_TRUE(a == a);
20+
ASSERT_TRUE(a == b);
21+
ASSERT_FALSE(a != a);
22+
ASSERT_FALSE(a != b);
23+
ASSERT_FALSE(a == c);
24+
ASSERT_TRUE(a != c);
25+
26+
ASSERT_FALSE(a == d);
27+
ASSERT_TRUE(a != d);
28+
ASSERT_FALSE(d == a);
29+
ASSERT_TRUE(d != a);
30+
31+
// ge
32+
ASSERT_TRUE(a >= a);
33+
ASSERT_TRUE(a >= b);
34+
ASSERT_TRUE(b >= a);
35+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
36+
EXPECT_THROW((void)(a >= c), c10::Error);
37+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
38+
EXPECT_THROW((void)(c >= a), c10::Error);
39+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
40+
EXPECT_THROW((void)(c >= 3), c10::Error);
41+
ASSERT_TRUE(c >= 2);
42+
ASSERT_TRUE(c >= 1);
43+
ASSERT_FALSE(1 >= c);
44+
45+
// lt
46+
ASSERT_FALSE(a < a);
47+
ASSERT_FALSE(a < b);
48+
ASSERT_FALSE(b < a);
49+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
50+
EXPECT_THROW((void)(a < c), c10::Error);
51+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
52+
EXPECT_THROW((void)(c < a), c10::Error);
53+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
54+
EXPECT_THROW((void)(3 < a), c10::Error);
55+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
56+
EXPECT_THROW((void)(2 < a), c10::Error);
57+
ASSERT_TRUE(1 < a);
58+
59+
// le
60+
ASSERT_TRUE(a <= a);
61+
ASSERT_TRUE(b <= a);
62+
ASSERT_TRUE(a <= b);
63+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
64+
EXPECT_THROW((void)(a <= c), c10::Error);
65+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
66+
EXPECT_THROW((void)(c <= a), c10::Error);
67+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
68+
EXPECT_THROW((void)(3 <= c), c10::Error);
69+
ASSERT_TRUE(2 <= c);
70+
ASSERT_TRUE(1 <= c);
71+
ASSERT_FALSE(c <= 1);
72+
73+
// gt
74+
ASSERT_FALSE(a > a);
75+
ASSERT_FALSE(b > a);
76+
ASSERT_FALSE(a > b);
77+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
78+
EXPECT_THROW((void)(a > c), c10::Error);
79+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
80+
EXPECT_THROW((void)(c > a), c10::Error);
81+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
82+
EXPECT_THROW((void)(a > 3), c10::Error);
83+
// NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
84+
EXPECT_THROW((void)(a > 2), c10::Error);
85+
ASSERT_TRUE(a > 1);
86+
}
87+
88+
TEST(SingletonIntTest, WiithFactor) {
89+
auto a = c10::SymInt(
90+
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 5)));
91+
auto b = c10::SymInt(
92+
c10::SymNode(c10::make_intrusive<c10::SingletonSymNodeImpl>(1, 10)));
93+
// eq
94+
ASSERT_FALSE(a == b);
95+
ASSERT_FALSE(a >= b);
96+
ASSERT_TRUE(b >= a);
97+
ASSERT_TRUE(a <= b);
98+
ASSERT_FALSE(b <= a);
99+
// ne
100+
ASSERT_TRUE(a != b);
101+
// mul
102+
ASSERT_TRUE(a * 2 == b);
103+
ASSERT_TRUE(a * 3 >= b);
104+
ASSERT_TRUE(a * 2 == 2 * a);
105+
}

torch/csrc/utils/python_dispatch.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
#include <ATen/FunctionalTensorWrapper.h>
77
#include <ATen/TensorSubclassLikeUtils.h>
88
#include <ATen/core/PythonOpRegistrationTrampoline.h>
9+
#include <ATen/core/SingletonSymNodeImpl.h>
910
#include <ATen/core/dispatch/Dispatcher.h>
11+
1012
#include <ATen/functorch/BatchedTensorImpl.h>
1113
#include <torch/library.h>
1214

@@ -15,7 +17,6 @@
1517
#include <torch/csrc/autograd/python_variable.h>
1618
#include <torch/csrc/jit/python/pybind_utils.h>
1719

18-
#include <c10/core/SingletonSymNodeImpl.h>
1920
#include <c10/util/flat_hash_map.h>
2021
#include <pybind11/operators.h>
2122
#include <pybind11/stl.h>

0 commit comments

Comments
 (0)