Skip to content

Commit c4c165b

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
Enable customization of atol and rtol (#3720)
Summary: Pull Request resolved: #3720 Enable customization of atol and rtol for TestSuite in codegen. The reason for this is: we need to relax atol and rtol to pass some tests. (D57686850) bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: copyrightly Differential Revision: D57698945 fbshipit-source-id: 82c4888a5ddeaa1ea2eb49a95652d998dc89c44a
1 parent 2194486 commit c4c165b

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

backends/vulkan/test/op_tests/cases.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,8 @@ def get_unary_ops_inputs():
682682
]
683683
)
684684
test_suite.storage_types = ["api::kTexture3D", "api::kBuffer"]
685+
test_suite.atol = "1e-4"
686+
test_suite.rtol = "1e-4"
685687
return test_suite
686688

687689

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,8 +600,8 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
600600
protected:
601601
ComputeGraph* graph;
602602
at::ScalarType test_dtype = at::kFloat;
603-
float rtol = 1e-4;
604-
float atol = 1e-4;
603+
float rtol = {rtol};
604+
float atol = {atol};
605605
606606
void SetUp() override {{
607607
GraphConfig config;
@@ -651,6 +651,8 @@ def generate_fixture_cpp(self) -> str:
651651
op_name=self.op_name,
652652
check_fn=check_fn,
653653
prepacked_check_fn=prepacked_check_fn,
654+
rtol=self.suite_def.rtol,
655+
atol=self.suite_def.atol,
654656
)
655657

656658
def gen_parameterization(self) -> str:

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def __init__(self, input_cases: List[Any]):
4747
self.prepacked_args: List[str] = []
4848
self.requires_prepack: bool = False
4949
self.dtypes: List[str] = ["at::kFloat", "at::kHalf"]
50+
self.atol: str = "1e-5"
51+
self.rtol: str = "1e-5"
5052

5153
def supports_prepack(self):
5254
return len(self.prepacked_args) > 0

0 commit comments

Comments
 (0)