@@ -21,9 +21,8 @@ TEST(Converters, ATenMMConvertsCorrectly) {
21
21
22
22
params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
23
23
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
24
- auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
25
24
26
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt , 2e-6 ));
25
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[ 0 ] , 2e-6 ));
27
26
}
28
27
29
28
TEST (Converters, ATenMMWithDiffShapesConvertsCorrectly) {
@@ -42,9 +41,131 @@ TEST(Converters, ATenMMWithDiffShapesConvertsCorrectly) {
42
41
43
42
params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
44
43
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
45
- auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
46
44
47
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt, 2e-6 ));
45
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
46
+ }
47
+
48
+ TEST (Converters, ATenMM1d2dConvertsCorrectly) {
49
+ const auto graph = R"IR(
50
+ graph(%0 : Tensor, %1 : Tensor):
51
+ %2 : Tensor = aten::matmul(%0, %1)
52
+ return (%2))IR" ;
53
+
54
+ auto g = std::make_shared<torch::jit::Graph>();
55
+ torch::jit::parseIR (graph, g.get ());
56
+
57
+ auto in1 = at::randint (0 , 5 , {10 }, {at::kCUDA });
58
+ auto in2 = at::randint (0 , 5 , {10 , 1 }, {at::kCUDA });
59
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
60
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
61
+
62
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
63
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
64
+
65
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
66
+ }
67
+
68
+ TEST (Converters, ATenMM1d3dConvertsCorrectly) {
69
+ const auto graph = R"IR(
70
+ graph(%0 : Tensor, %1 : Tensor):
71
+ %2 : Tensor = aten::matmul(%0, %1)
72
+ return (%2))IR" ;
73
+
74
+ auto g = std::make_shared<torch::jit::Graph>();
75
+ torch::jit::parseIR (graph, g.get ());
76
+
77
+ auto in1 = at::randint (0 , 5 , {10 }, {at::kCUDA });
78
+ auto in2 = at::randint (0 , 5 , {2 , 10 , 8 }, {at::kCUDA });
79
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
80
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
81
+
82
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
83
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
84
+
85
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
86
+ }
87
+
88
+ TEST (Converters, ATenMM1d4dConvertsCorrectly) {
89
+ const auto graph = R"IR(
90
+ graph(%0 : Tensor, %1 : Tensor):
91
+ %2 : Tensor = aten::matmul(%0, %1)
92
+ return (%2))IR" ;
93
+
94
+ auto g = std::make_shared<torch::jit::Graph>();
95
+ torch::jit::parseIR (graph, g.get ());
96
+
97
+ auto in1 = at::randint (0 , 5 , {10 }, {at::kCUDA });
98
+ auto in2 = at::randint (0 , 5 , {2 , 3 , 10 , 8 }, {at::kCUDA });
99
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
100
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
101
+
102
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
103
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
104
+
105
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
106
+ }
107
+
108
+ TEST (Converters, ATenMM3d1dConvertsCorrectly) {
109
+ const auto graph = R"IR(
110
+ graph(%0 : Tensor, %1 : Tensor):
111
+ %2 : Tensor = aten::matmul(%0, %1)
112
+ return (%2))IR" ;
113
+
114
+ auto g = std::make_shared<torch::jit::Graph>();
115
+ torch::jit::parseIR (graph, g.get ());
116
+
117
+ auto in1 = at::randint (0 , 5 , {2 , 10 , 8 }, {at::kCUDA });
118
+ auto in2 = at::randint (0 , 5 , {8 }, {at::kCUDA });
119
+
120
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
121
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
122
+
123
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
124
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
125
+
126
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
127
+ }
128
+
129
+ TEST (Converters, ATenMM2d1dConvertsCorrectly) {
130
+ const auto graph = R"IR(
131
+ graph(%0 : Tensor, %1 : Tensor):
132
+ %2 : Tensor = aten::matmul(%0, %1)
133
+ return (%2))IR" ;
134
+
135
+ auto g = std::make_shared<torch::jit::Graph>();
136
+ torch::jit::parseIR (graph, g.get ());
137
+
138
+ auto in1 = at::randint (0 , 5 , {1 , 10 }, {at::kCUDA });
139
+ auto in2 = at::randint (0 , 5 , {10 }, {at::kCUDA });
140
+
141
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
142
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
143
+
144
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
145
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
146
+
147
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
148
+ }
149
+
150
+ TEST (Converters, ATenMM4d1dConvertsCorrectly) {
151
+ const auto graph = R"IR(
152
+ graph(%0 : Tensor, %1 : Tensor):
153
+ %2 : Tensor = aten::matmul(%0, %1)
154
+ return (%2))IR" ;
155
+
156
+ auto g = std::make_shared<torch::jit::Graph>();
157
+ torch::jit::parseIR (graph, g.get ());
158
+
159
+ auto in1 = at::randint (0 , 5 , {2 , 3 , 10 , 8 }, {at::kCUDA });
160
+ auto in2 = at::randint (0 , 5 , {8 }, {at::kCUDA });
161
+
162
+ auto params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
163
+ auto jit_results = torch_tensorrt::tests::util::RunGraph (g, params, {in1, in2});
164
+
165
+ params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
166
+ auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
167
+
168
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
48
169
}
49
170
50
171
TEST (Converters, ATenBMMConvertsCorrectly) {
@@ -63,9 +184,8 @@ TEST(Converters, ATenBMMConvertsCorrectly) {
63
184
64
185
params = torch_tensorrt::core::ir::get_static_params (g->inputs (), {});
65
186
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine (g, params, {in1, in2});
66
- auto trt = trt_results[0 ].reshape_as (jit_results[0 ]);
67
187
68
- ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt , 2e-6 ));
188
+ ASSERT_TRUE (torch_tensorrt::tests::util::almostEqual (jit_results[0 ], trt_results[ 0 ] , 2e-6 ));
69
189
}
70
190
71
191
TEST (Converters, ATenBADDBMMConvertsCorrectly) {
0 commit comments