Skip to content

Commit c51ba38

Browse files
authored
Merge pull request #1055 from pytorch/anuragd/test_update_atol_rtol
feat(//tests): Update rtol and atol based tolerance for test cases
2 parents 0c53125 + 0f26bed commit c51ba38

File tree

2 files changed

+18
-13
lines changed

2 files changed

+18
-13
lines changed

tests/util/util.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,30 @@ namespace torch_tensorrt {
55
namespace tests {
66
namespace util {
77

8-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol = 1e-8, float rtol = 1e-5) {
9-
LOG_GRAPH(a << std::endl << b << std::endl);
10-
auto a_float = a.toType(at::kFloat);
11-
auto b_float = b.toType(at::kFloat);
8+
bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5) {
9+
std::ostringstream ss;
10+
ss << computed_tensor << std::endl << gt_tensor << std::endl;
11+
ss << " atol: " << atol << " rtol: " << rtol << std::endl;
1212

13-
auto diff = a_float - b_float;
14-
auto result = diff.abs().max().item<float>() - (atol + rtol * b.abs().max().item<float>());
13+
LOG_GRAPH(ss.str());
14+
auto computed_tensor_float = computed_tensor.toType(at::kFloat);
15+
auto gt_tensor_float = gt_tensor.toType(at::kFloat);
1516

16-
std::cout << "Max Difference: " << result << std::endl;
17-
std::cout << "Acceptable Threshold: " << threshold << std::endl;
17+
auto diff = computed_tensor_float - gt_tensor_float;
18+
auto result = diff.abs().max().item<float>();
19+
auto threshold = atol + (rtol * gt_tensor.abs().max().item<float>());
20+
21+
LOG_GRAPH(std::string("Max Difference: ") + std::to_string(result));
22+
LOG_GRAPH(std::string("Acceptable Threshold: ") + std::to_string(threshold));
1823

1924
return result <= threshold;
2025
}
2126

22-
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
23-
LOG_GRAPH(a << std::endl << b << std::endl);
24-
std::cout << "Max Difference: " << (a - b).abs().max().item<float>() << std::endl;
27+
bool exactlyEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor) {
28+
LOG_GRAPH(computed_tensor << std::endl << gt_tensor << std::endl);
29+
std::cout << "Max Difference: " << (computed_tensor - gt_tensor).abs().max().item<float>() << std::endl;
2530

26-
return (a - b).abs().max().item<float>() == 0.f;
31+
return (computed_tensor - gt_tensor).abs().max().item<float>() == 0.f;
2732
}
2833

2934
} // namespace util

tests/util/util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace torch_tensorrt {
1111
namespace tests {
1212
namespace util {
1313

14-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float threshold, float atol = 1e-8, float rtol = 1e-5);
14+
bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5);
1515

1616
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
1717

0 commit comments

Comments
 (0)