Skip to content

Commit 00adb21

Browse files
committed
chore!: Changing the names of tensor
Signed-off-by: Anurag Dixit <[email protected]>
1 parent c7f0147 commit 00adb21

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

tests/util/util.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,30 @@ namespace torch_tensorrt {
55
namespace tests {
66
namespace util {
77

8-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float atol = 1e-8, float rtol = 1e-5) {
9-
LOG_GRAPH(a << std::endl << b << std::endl);
10-
auto a_float = a.toType(at::kFloat);
11-
auto b_float = b.toType(at::kFloat);
8+
bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5) {
9+
std::ostringstream ss;
10+
ss << computed_tensor << std::endl << gt_tensor << std::endl;
11+
ss << " atol: " << atol << " rtol: " << rtol << std::endl;
1212

13-
auto diff = a_float - b_float;
13+
LOG_GRAPH(ss.str());
14+
auto computed_tensor_float = computed_tensor.toType(at::kFloat);
15+
auto gt_tensor_float = gt_tensor.toType(at::kFloat);
16+
17+
auto diff = computed_tensor_float - gt_tensor_float;
1418
auto result = diff.abs().max().item<float>();
15-
auto threshold = atol + (rtol * b.abs().max().item<float>());
19+
auto threshold = atol + (rtol * gt_tensor.abs().max().item<float>());
1620

17-
std::cout << "Max Difference: " << result << std::endl;
18-
std::cout << "Acceptable Threshold: " << threshold << std::endl;
21+
LOG_GRAPH(std::string("Max Difference: ") + std::to_string(result) );
22+
LOG_GRAPH(std::string("Acceptable Threshold: ") + std::to_string(threshold));
1923

2024
return result <= threshold;
2125
}
2226

23-
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b) {
24-
LOG_GRAPH(a << std::endl << b << std::endl);
25-
std::cout << "Max Difference: " << (a - b).abs().max().item<float>() << std::endl;
27+
bool exactlyEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor) {
28+
LOG_GRAPH(computed_tensor << std::endl << gt_tensor << std::endl);
29+
std::cout << "Max Difference: " << (computed_tensor - gt_tensor).abs().max().item<float>() << std::endl;
2630

27-
return (a - b).abs().max().item<float>() == 0.f;
31+
return (computed_tensor - gt_tensor).abs().max().item<float>() == 0.f;
2832
}
2933

3034
} // namespace util

tests/util/util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace torch_tensorrt {
1111
namespace tests {
1212
namespace util {
1313

14-
bool almostEqual(const at::Tensor& a, const at::Tensor& b, float atol = 1e-8, float rtol = 1e-5);
14+
bool almostEqual(const at::Tensor& computed_tensor, const at::Tensor& gt_tensor, float atol = 1e-8, float rtol = 1e-5);
1515

1616
bool exactlyEqual(const at::Tensor& a, const at::Tensor& b);
1717

0 commit comments

Comments
 (0)