7
7
import itertools
8
8
import warnings
9
9
from copy import deepcopy
10
+ from random import randint
10
11
11
12
import torch
12
13
from botorch import settings
@@ -65,7 +66,7 @@ def transform(self, X):
65
66
66
67
67
68
class TestInputTransforms (BotorchTestCase ):
68
- def test_abstract_base_input_transform (self ):
69
+ def test_abstract_base_input_transform (self ) -> None :
69
70
with self .assertRaises (TypeError ):
70
71
InputTransform ()
71
72
X = torch .zeros ([1 ])
@@ -140,7 +141,9 @@ def test_abstract_base_input_transform(self):
140
141
with self .assertRaises (NotImplementedError ):
141
142
affine ._update_coefficients (X )
142
143
143
- def test_normalize (self ):
144
+ def test_normalize (self ) -> None :
145
+ # set seed to range where this is known to not be flaky
146
+ torch .manual_seed (randint (0 , 1000 ))
144
147
for dtype in (torch .float , torch .double ):
145
148
# basic init, learned bounds
146
149
nlz = Normalize (d = 2 )
@@ -259,7 +262,9 @@ def test_normalize(self):
259
262
[X .min (dim = - 2 , keepdim = True )[0 ], X .max (dim = - 2 , keepdim = True )[0 ]],
260
263
dim = - 2 ,
261
264
)
262
- self .assertAllClose (nlz .bounds , expected_bounds )
265
+ atol = 1e-6 if dtype is torch .float32 else 1e-12
266
+ rtol = 1e-4 if dtype is torch .float32 else 1e-8
267
+ self .assertAllClose (nlz .bounds , expected_bounds , atol = atol , rtol = rtol )
263
268
# test errors on wrong shape
264
269
nlz = Normalize (d = 2 , batch_shape = batch_shape )
265
270
X = torch .randn (* batch_shape , 2 , 1 , device = self .device , dtype = dtype )
@@ -526,6 +531,8 @@ def test_chained_input_transform(self):
526
531
ds = (1 , 2 )
527
532
batch_shapes = (torch .Size (), torch .Size ([2 ]))
528
533
dtypes = (torch .float , torch .double )
534
+ # set seed to range where this is known to not be flaky
535
+ torch .manual_seed (randint (0 , 1000 ))
529
536
530
537
for d , batch_shape , dtype in itertools .product (ds , batch_shapes , dtypes ):
531
538
bounds = torch .tensor (
@@ -591,27 +598,25 @@ def test_chained_input_transform(self):
591
598
tf = ChainedInputTransform (stz = tf1 , pert = tf2 )
592
599
self .assertTrue (tf .is_one_to_many )
593
600
594
- def test_round_transform (self ):
595
- for dtype in (torch .float , torch .double ):
596
- # basic init
597
- int_idcs = [0 , 4 ]
598
- categorical_feats = {2 : 2 , 5 : 3 }
599
- # test deprecation warning
600
- with warnings .catch_warnings (record = True ) as ws , settings .debug (True ):
601
- Round (indices = int_idcs )
602
- self .assertTrue (
603
- any (issubclass (w .category , DeprecationWarning ) for w in ws )
604
- )
605
- round_tf = Round (
606
- integer_indices = int_idcs , categorical_features = categorical_feats
607
- )
608
- self .assertEqual (round_tf .integer_indices .tolist (), int_idcs )
609
- self .assertEqual (round_tf .categorical_features , categorical_feats )
610
- self .assertTrue (round_tf .training )
611
- self .assertFalse (round_tf .approximate )
612
- self .assertEqual (round_tf .tau , 1e-3 )
613
- self .assertTrue (round_tf .equals (Round (** round_tf .get_init_args ())))
601
+ def test_round_transform_init (self ) -> None :
602
+ # basic init
603
+ int_idcs = [0 , 4 ]
604
+ categorical_feats = {2 : 2 , 5 : 3 }
605
+ # test deprecation warning
606
+ with warnings .catch_warnings (record = True ) as ws , settings .debug (True ):
607
+ Round (indices = int_idcs )
608
+ self .assertTrue (any (issubclass (w .category , DeprecationWarning ) for w in ws ))
609
+ round_tf = Round (
610
+ integer_indices = int_idcs , categorical_features = categorical_feats
611
+ )
612
+ self .assertEqual (round_tf .integer_indices .tolist (), int_idcs )
613
+ self .assertEqual (round_tf .categorical_features , categorical_feats )
614
+ self .assertTrue (round_tf .training )
615
+ self .assertFalse (round_tf .approximate )
616
+ self .assertEqual (round_tf .tau , 1e-3 )
617
+ self .assertTrue (round_tf .equals (Round (** round_tf .get_init_args ())))
614
618
619
+ for dtype in (torch .float , torch .double ):
615
620
# With tensor indices.
616
621
round_tf = Round (
617
622
integer_indices = torch .tensor (int_idcs , dtype = dtype , device = self .device ),
@@ -620,11 +625,22 @@ def test_round_transform(self):
620
625
self .assertEqual (round_tf .integer_indices .tolist (), int_idcs )
621
626
self .assertTrue (round_tf .equals (Round (** round_tf .get_init_args ())))
622
627
623
- # basic usage
624
- for batch_shape , approx , categorical_features in itertools .product (
625
- (torch .Size (), torch .Size ([3 ])),
626
- (False , True ),
627
- (None , categorical_feats ),
628
+ def test_round_transform (self ) -> None :
629
+ int_idcs = [0 , 4 ]
630
+ categorical_feats = {2 : 2 , 5 : 3 }
631
+ # set seed to range where this is known to not be flaky
632
+ torch .manual_seed (randint (0 , 1000 ))
633
+ for dtype , batch_shape , approx , categorical_features in itertools .product (
634
+ (torch .float , torch .double ),
635
+ (torch .Size (), torch .Size ([3 ])),
636
+ (False , True ),
637
+ (None , categorical_feats ),
638
+ ):
639
+ with self .subTest (
640
+ dtype = dtype ,
641
+ batch_shape = batch_shape ,
642
+ approx = approx ,
643
+ categorical_features = categorical_features ,
628
644
):
629
645
X = torch .rand (* batch_shape , 4 , 8 , device = self .device , dtype = dtype )
630
646
X [..., int_idcs ] *= 5
@@ -649,11 +665,15 @@ def test_round_transform(self):
649
665
if approx :
650
666
# check that approximate rounding is closer to rounded values than
651
667
# the original inputs
668
+ dist_approx_to_rounded = (
669
+ X_rounded [..., int_idcs ] - exact_rounded_X_ints
670
+ ).abs ()
671
+ dist_orig_to_rounded = (
672
+ X [..., int_idcs ] - exact_rounded_X_ints
673
+ ).abs ()
674
+ tol = 1e-5 if dtype == torch .float32 else 1e-11
652
675
self .assertTrue (
653
- (
654
- (X_rounded [..., int_idcs ] - exact_rounded_X_ints ).abs ()
655
- <= (X [..., int_idcs ] - exact_rounded_X_ints ).abs ()
656
- ).all ()
676
+ (dist_approx_to_rounded <= dist_orig_to_rounded + tol ).all ()
657
677
)
658
678
self .assertFalse (
659
679
torch .equal (X_rounded [..., int_idcs ], exact_rounded_X_ints )
@@ -756,7 +776,9 @@ def test_round_transform(self):
756
776
torch .equal (round_tf .preprocess_transform (X ), X_rounded )
757
777
)
758
778
759
- def test_log10_transform (self ):
779
+ def test_log10_transform (self ) -> None :
780
+ # set seed to range where this is known to not be flaky
781
+ torch .manual_seed (randint (0 , 1000 ))
760
782
for dtype in (torch .float , torch .double ):
761
783
# basic init
762
784
indices = [0 , 2 ]
@@ -810,7 +832,9 @@ def test_log10_transform(self):
810
832
log_tf .transform_on_train = True
811
833
self .assertTrue (torch .equal (log_tf .preprocess_transform (X ), X_tf ))
812
834
813
- def test_warp_transform (self ):
835
+ def test_warp_transform (self ) -> None :
836
+ # set seed to range where this is known to not be flaky
837
+ torch .manual_seed (randint (0 , 1000 ))
814
838
for dtype , batch_shape , warp_batch_shape in itertools .product (
815
839
(torch .float , torch .double ),
816
840
(torch .Size (), torch .Size ([3 ])),
@@ -955,7 +979,9 @@ def test_warp_transform(self):
955
979
warp_tf ._set_concentration (i = 1 , value = 3.0 )
956
980
self .assertTrue ((warp_tf .concentration1 == 3.0 ).all ())
957
981
958
- def test_one_hot_to_numeric (self ):
982
+ def test_one_hot_to_numeric (self ) -> None :
983
+ # set seed to range where this is known to not be flaky
984
+ torch .manual_seed (randint (0 , 1000 ))
959
985
dim = 8
960
986
# test exception when categoricals are not the trailing dimensions
961
987
categorical_features = {0 : 2 }
@@ -1042,6 +1068,9 @@ def test_append_features(self):
1042
1068
with self .assertRaises (ValueError ):
1043
1069
AppendFeatures (torch .ones (3 , 4 , 2 ))
1044
1070
1071
+ # set seed to range where this is known to not be flaky
1072
+ torch .manual_seed (randint (0 , 100 ))
1073
+
1045
1074
for dtype in (torch .float , torch .double ):
1046
1075
feature_set = (
1047
1076
torch .linspace (0 , 1 , 6 ).view (3 , 2 ).to (device = self .device , dtype = dtype )
@@ -1106,6 +1135,9 @@ def f2(x: Tensor, n_f: int = 1) -> Tensor:
1106
1135
result = x [..., - 2 :].unsqueeze (- 2 )
1107
1136
return result .expand (* result .shape [:- 2 ], n_f , - 1 )
1108
1137
1138
+ # set seed to range where this is known to not be flaky
1139
+ torch .manual_seed (randint (0 , 100 ))
1140
+
1109
1141
for dtype in [torch .float , torch .double ]:
1110
1142
tkwargs = {"device" : self .device , "dtype" : dtype }
1111
1143
@@ -1336,6 +1368,9 @@ def test_filter_features(self):
1336
1368
with self .assertRaises (ValueError ):
1337
1369
FilterFeatures (torch .tensor ([0 , 1 , 1 ], dtype = torch .long ))
1338
1370
1371
+ # set seed to range where this is known to not be flaky
1372
+ torch .manual_seed (randint (0 , 100 ))
1373
+
1339
1374
for dtype in (torch .float , torch .double ):
1340
1375
feature_indices = torch .tensor (
1341
1376
[0 , 2 , 3 , 5 ], dtype = torch .long , device = self .device
0 commit comments