Skip to content

Commit 5a8d27c

Browse files
esantorellafacebook-github-bot
authored andcommitted
Stop input transform tests from being flaky (#1896)
Summary: ## Motivation Tests were flaky. [x] Break one long test into two [x] Loosen a tolerance [x] Set seeds to be randomly chosen from the range (0, 1000), where I made sure this is not flaky ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #1896 Test Plan: Ran each test 1000 times with each seed in (0, 1000) Reviewed By: Balandat Differential Revision: D46842516 Pulled By: esantorella fbshipit-source-id: 91bbe7ad48f1ff46b68a81257ca9e6d35c7ce60e
1 parent 04d8f05 commit 5a8d27c

File tree

1 file changed

+70
-35
lines changed

1 file changed

+70
-35
lines changed

test/models/transforms/test_input.py

Lines changed: 70 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import itertools
88
import warnings
99
from copy import deepcopy
10+
from random import randint
1011

1112
import torch
1213
from botorch import settings
@@ -65,7 +66,7 @@ def transform(self, X):
6566

6667

6768
class TestInputTransforms(BotorchTestCase):
68-
def test_abstract_base_input_transform(self):
69+
def test_abstract_base_input_transform(self) -> None:
6970
with self.assertRaises(TypeError):
7071
InputTransform()
7172
X = torch.zeros([1])
@@ -140,7 +141,9 @@ def test_abstract_base_input_transform(self):
140141
with self.assertRaises(NotImplementedError):
141142
affine._update_coefficients(X)
142143

143-
def test_normalize(self):
144+
def test_normalize(self) -> None:
145+
# set seed to range where this is known to not be flaky
146+
torch.manual_seed(randint(0, 1000))
144147
for dtype in (torch.float, torch.double):
145148
# basic init, learned bounds
146149
nlz = Normalize(d=2)
@@ -259,7 +262,9 @@ def test_normalize(self):
259262
[X.min(dim=-2, keepdim=True)[0], X.max(dim=-2, keepdim=True)[0]],
260263
dim=-2,
261264
)
262-
self.assertAllClose(nlz.bounds, expected_bounds)
265+
atol = 1e-6 if dtype is torch.float32 else 1e-12
266+
rtol = 1e-4 if dtype is torch.float32 else 1e-8
267+
self.assertAllClose(nlz.bounds, expected_bounds, atol=atol, rtol=rtol)
263268
# test errors on wrong shape
264269
nlz = Normalize(d=2, batch_shape=batch_shape)
265270
X = torch.randn(*batch_shape, 2, 1, device=self.device, dtype=dtype)
@@ -526,6 +531,8 @@ def test_chained_input_transform(self):
526531
ds = (1, 2)
527532
batch_shapes = (torch.Size(), torch.Size([2]))
528533
dtypes = (torch.float, torch.double)
534+
# set seed to range where this is known to not be flaky
535+
torch.manual_seed(randint(0, 1000))
529536

530537
for d, batch_shape, dtype in itertools.product(ds, batch_shapes, dtypes):
531538
bounds = torch.tensor(
@@ -591,27 +598,25 @@ def test_chained_input_transform(self):
591598
tf = ChainedInputTransform(stz=tf1, pert=tf2)
592599
self.assertTrue(tf.is_one_to_many)
593600

594-
def test_round_transform(self):
595-
for dtype in (torch.float, torch.double):
596-
# basic init
597-
int_idcs = [0, 4]
598-
categorical_feats = {2: 2, 5: 3}
599-
# test deprecation warning
600-
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
601-
Round(indices=int_idcs)
602-
self.assertTrue(
603-
any(issubclass(w.category, DeprecationWarning) for w in ws)
604-
)
605-
round_tf = Round(
606-
integer_indices=int_idcs, categorical_features=categorical_feats
607-
)
608-
self.assertEqual(round_tf.integer_indices.tolist(), int_idcs)
609-
self.assertEqual(round_tf.categorical_features, categorical_feats)
610-
self.assertTrue(round_tf.training)
611-
self.assertFalse(round_tf.approximate)
612-
self.assertEqual(round_tf.tau, 1e-3)
613-
self.assertTrue(round_tf.equals(Round(**round_tf.get_init_args())))
601+
def test_round_transform_init(self) -> None:
602+
# basic init
603+
int_idcs = [0, 4]
604+
categorical_feats = {2: 2, 5: 3}
605+
# test deprecation warning
606+
with warnings.catch_warnings(record=True) as ws, settings.debug(True):
607+
Round(indices=int_idcs)
608+
self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
609+
round_tf = Round(
610+
integer_indices=int_idcs, categorical_features=categorical_feats
611+
)
612+
self.assertEqual(round_tf.integer_indices.tolist(), int_idcs)
613+
self.assertEqual(round_tf.categorical_features, categorical_feats)
614+
self.assertTrue(round_tf.training)
615+
self.assertFalse(round_tf.approximate)
616+
self.assertEqual(round_tf.tau, 1e-3)
617+
self.assertTrue(round_tf.equals(Round(**round_tf.get_init_args())))
614618

619+
for dtype in (torch.float, torch.double):
615620
# With tensor indices.
616621
round_tf = Round(
617622
integer_indices=torch.tensor(int_idcs, dtype=dtype, device=self.device),
@@ -620,11 +625,22 @@ def test_round_transform(self):
620625
self.assertEqual(round_tf.integer_indices.tolist(), int_idcs)
621626
self.assertTrue(round_tf.equals(Round(**round_tf.get_init_args())))
622627

623-
# basic usage
624-
for batch_shape, approx, categorical_features in itertools.product(
625-
(torch.Size(), torch.Size([3])),
626-
(False, True),
627-
(None, categorical_feats),
628+
def test_round_transform(self) -> None:
629+
int_idcs = [0, 4]
630+
categorical_feats = {2: 2, 5: 3}
631+
# set seed to range where this is known to not be flaky
632+
torch.manual_seed(randint(0, 1000))
633+
for dtype, batch_shape, approx, categorical_features in itertools.product(
634+
(torch.float, torch.double),
635+
(torch.Size(), torch.Size([3])),
636+
(False, True),
637+
(None, categorical_feats),
638+
):
639+
with self.subTest(
640+
dtype=dtype,
641+
batch_shape=batch_shape,
642+
approx=approx,
643+
categorical_features=categorical_features,
628644
):
629645
X = torch.rand(*batch_shape, 4, 8, device=self.device, dtype=dtype)
630646
X[..., int_idcs] *= 5
@@ -649,11 +665,15 @@ def test_round_transform(self):
649665
if approx:
650666
# check that approximate rounding is closer to rounded values than
651667
# the original inputs
668+
dist_approx_to_rounded = (
669+
X_rounded[..., int_idcs] - exact_rounded_X_ints
670+
).abs()
671+
dist_orig_to_rounded = (
672+
X[..., int_idcs] - exact_rounded_X_ints
673+
).abs()
674+
tol = 1e-5 if dtype == torch.float32 else 1e-11
652675
self.assertTrue(
653-
(
654-
(X_rounded[..., int_idcs] - exact_rounded_X_ints).abs()
655-
<= (X[..., int_idcs] - exact_rounded_X_ints).abs()
656-
).all()
676+
(dist_approx_to_rounded <= dist_orig_to_rounded + tol).all()
657677
)
658678
self.assertFalse(
659679
torch.equal(X_rounded[..., int_idcs], exact_rounded_X_ints)
@@ -756,7 +776,9 @@ def test_round_transform(self):
756776
torch.equal(round_tf.preprocess_transform(X), X_rounded)
757777
)
758778

759-
def test_log10_transform(self):
779+
def test_log10_transform(self) -> None:
780+
# set seed to range where this is known to not be flaky
781+
torch.manual_seed(randint(0, 1000))
760782
for dtype in (torch.float, torch.double):
761783
# basic init
762784
indices = [0, 2]
@@ -810,7 +832,9 @@ def test_log10_transform(self):
810832
log_tf.transform_on_train = True
811833
self.assertTrue(torch.equal(log_tf.preprocess_transform(X), X_tf))
812834

813-
def test_warp_transform(self):
835+
def test_warp_transform(self) -> None:
836+
# set seed to range where this is known to not be flaky
837+
torch.manual_seed(randint(0, 1000))
814838
for dtype, batch_shape, warp_batch_shape in itertools.product(
815839
(torch.float, torch.double),
816840
(torch.Size(), torch.Size([3])),
@@ -955,7 +979,9 @@ def test_warp_transform(self):
955979
warp_tf._set_concentration(i=1, value=3.0)
956980
self.assertTrue((warp_tf.concentration1 == 3.0).all())
957981

958-
def test_one_hot_to_numeric(self):
982+
def test_one_hot_to_numeric(self) -> None:
983+
# set seed to range where this is known to not be flaky
984+
torch.manual_seed(randint(0, 1000))
959985
dim = 8
960986
# test exception when categoricals are not the trailing dimensions
961987
categorical_features = {0: 2}
@@ -1042,6 +1068,9 @@ def test_append_features(self):
10421068
with self.assertRaises(ValueError):
10431069
AppendFeatures(torch.ones(3, 4, 2))
10441070

1071+
# set seed to range where this is known to not be flaky
1072+
torch.manual_seed(randint(0, 100))
1073+
10451074
for dtype in (torch.float, torch.double):
10461075
feature_set = (
10471076
torch.linspace(0, 1, 6).view(3, 2).to(device=self.device, dtype=dtype)
@@ -1106,6 +1135,9 @@ def f2(x: Tensor, n_f: int = 1) -> Tensor:
11061135
result = x[..., -2:].unsqueeze(-2)
11071136
return result.expand(*result.shape[:-2], n_f, -1)
11081137

1138+
# set seed to range where this is known to not be flaky
1139+
torch.manual_seed(randint(0, 100))
1140+
11091141
for dtype in [torch.float, torch.double]:
11101142
tkwargs = {"device": self.device, "dtype": dtype}
11111143

@@ -1336,6 +1368,9 @@ def test_filter_features(self):
13361368
with self.assertRaises(ValueError):
13371369
FilterFeatures(torch.tensor([0, 1, 1], dtype=torch.long))
13381370

1371+
# set seed to range where this is known to not be flaky
1372+
torch.manual_seed(randint(0, 100))
1373+
13391374
for dtype in (torch.float, torch.double):
13401375
feature_indices = torch.tensor(
13411376
[0, 2, 3, 5], dtype=torch.long, device=self.device

0 commit comments

Comments
 (0)