Skip to content

Commit 5fc3423

Browse files
committed
ROCm: Fix test_nadam
Change the rtol level Signed-off-by: Jagadish Krishnamoorthy <[email protected]>
1 parent a6c4396 commit 5fc3423

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

test/test_optim.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
_LRScheduler, CyclicLR, CosineAnnealingWarmRestarts, OneCycleLR, ChainedScheduler, \
2121
EPOCH_DEPRECATION_WARNING
2222
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
23-
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_UBSAN, load_tests
23+
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_ROCM, TEST_WITH_UBSAN, load_tests
2424
# load_tests from common_utils is used to automatically filter tests for
2525
# sharding on sandcastle. This line silences flake warnings
2626
load_tests = load_tests
@@ -661,6 +661,8 @@ def test_adadelta_complex(self):
661661
)
662662

663663
def test_nadam(self):
664+
if TEST_WITH_ROCM:
665+
self.rel_tol = 1e-5
664666
for optimizer in [optim.NAdam, optim_mt.NAdam]:
665667
self._test_basic_cases(
666668
lambda weight, bias: optimizer([weight, bias], lr=1e-3)

0 commit comments

Comments
 (0)