@@ -5,25 +5,30 @@ namespace torch_tensorrt {
5
5
namespace tests {
6
6
namespace util {
7
7
8
- bool almostEqual (const at::Tensor& a , const at::Tensor& b, float threshold , float atol = 1e-8 , float rtol = 1e-5 ) {
9
- LOG_GRAPH (a << std::endl << b << std::endl) ;
10
- auto a_float = a. toType (at:: kFloat ) ;
11
- auto b_float = b. toType (at:: kFloat ) ;
8
+ bool almostEqual (const at::Tensor& computed_tensor , const at::Tensor& gt_tensor , float atol = 1e-8 , float rtol = 1e-5 ) {
9
+ std::ostringstream ss ;
10
+ ss << computed_tensor << std::endl << gt_tensor << std::endl ;
11
+ ss << " atol: " << atol << " rtol: " << rtol << std::endl ;
12
12
13
- auto diff = a_float - b_float;
14
- auto result = diff.abs ().max ().item <float >() - (atol + rtol * b.abs ().max ().item <float >());
13
+ LOG_GRAPH (ss.str ());
14
+ auto computed_tensor_float = computed_tensor.toType (at::kFloat );
15
+ auto gt_tensor_float = gt_tensor.toType (at::kFloat );
15
16
16
- std::cout << " Max Difference: " << result << std::endl;
17
- std::cout << " Acceptable Threshold: " << threshold << std::endl;
17
+ auto diff = computed_tensor_float - gt_tensor_float;
18
+ auto result = diff.abs ().max ().item <float >();
19
+ auto threshold = atol + (rtol * gt_tensor.abs ().max ().item <float >());
20
+
21
+ LOG_GRAPH (std::string (" Max Difference: " ) + std::to_string (result));
22
+ LOG_GRAPH (std::string (" Acceptable Threshold: " ) + std::to_string (threshold));
18
23
19
24
return result <= threshold;
20
25
}
21
26
22
- bool exactlyEqual (const at::Tensor& a , const at::Tensor& b ) {
23
- LOG_GRAPH (a << std::endl << b << std::endl);
24
- std::cout << " Max Difference: " << (a - b ).abs ().max ().item <float >() << std::endl;
27
+ bool exactlyEqual (const at::Tensor& computed_tensor , const at::Tensor& gt_tensor ) {
28
+ LOG_GRAPH (computed_tensor << std::endl << gt_tensor << std::endl);
29
+ std::cout << " Max Difference: " << (computed_tensor - gt_tensor ).abs ().max ().item <float >() << std::endl;
25
30
26
- return (a - b ).abs ().max ().item <float >() == 0 .f ;
31
+ return (computed_tensor - gt_tensor ).abs ().max ().item <float >() == 0 .f ;
27
32
}
28
33
29
34
} // namespace util
0 commit comments