@@ -5,26 +5,30 @@ namespace torch_tensorrt {
5
5
namespace tests {
6
6
namespace util {
7
7
8
- bool almostEqual (const at::Tensor& a , const at::Tensor& b , float atol = 1e-8 , float rtol = 1e-5 ) {
9
- LOG_GRAPH (a << std::endl << b << std::endl) ;
10
- auto a_float = a. toType (at:: kFloat ) ;
11
- auto b_float = b. toType (at:: kFloat ) ;
8
+ bool almostEqual (const at::Tensor& computed_tensor , const at::Tensor& gt_tensor , float atol = 1e-8 , float rtol = 1e-5 ) {
9
+ std::ostringstream ss ;
10
+ ss << computed_tensor << std::endl << gt_tensor << std::endl ;
11
+ ss << " atol: " << atol << " rtol: " << rtol << std::endl ;
12
12
13
- auto diff = a_float - b_float;
13
+ LOG_GRAPH (ss.str ());
14
+ auto computed_tensor_float = computed_tensor.toType (at::kFloat );
15
+ auto gt_tensor_float = gt_tensor.toType (at::kFloat );
16
+
17
+ auto diff = computed_tensor_float - gt_tensor_float;
14
18
auto result = diff.abs ().max ().item <float >();
15
- auto threshold = atol + (rtol * b .abs ().max ().item <float >());
19
+ auto threshold = atol + (rtol * gt_tensor .abs ().max ().item <float >());
16
20
17
- std::cout << " Max Difference: " << result << std::endl ;
18
- std::cout << " Acceptable Threshold: " << threshold << std::endl ;
21
+ LOG_GRAPH ( std::string ( " Max Difference: " ) + std::to_string (result) ) ;
22
+ LOG_GRAPH ( std::string ( " Acceptable Threshold: " ) + std::to_string (threshold)) ;
19
23
20
24
return result <= threshold;
21
25
}
22
26
23
- bool exactlyEqual (const at::Tensor& a , const at::Tensor& b ) {
24
- LOG_GRAPH (a << std::endl << b << std::endl);
25
- std::cout << " Max Difference: " << (a - b ).abs ().max ().item <float >() << std::endl;
27
+ bool exactlyEqual (const at::Tensor& computed_tensor , const at::Tensor& gt_tensor ) {
28
+ LOG_GRAPH (computed_tensor << std::endl << gt_tensor << std::endl);
29
+ std::cout << " Max Difference: " << (computed_tensor - gt_tensor ).abs ().max ().item <float >() << std::endl;
26
30
27
- return (a - b ).abs ().max ().item <float >() == 0 .f ;
31
+ return (computed_tensor - gt_tensor ).abs ().max ().item <float >() == 0 .f ;
28
32
}
29
33
30
34
} // namespace util
0 commit comments