@@ -390,7 +390,6 @@ def test_construct_inputs_qEI(self):
390
390
self .assertTrue (torch .equal (kwargs ["objective" ].weights , objective .weights ))
391
391
self .assertTrue (torch .equal (kwargs ["X_pending" ], X_pending ))
392
392
self .assertIsNone (kwargs ["sampler" ])
393
- self .assertIsNone (kwargs ["constraints" ])
394
393
self .assertIsInstance (kwargs ["eta" ], float )
395
394
self .assertTrue (kwargs ["eta" ] < 1 )
396
395
multi_Y = torch .cat ([d .Y () for d in self .blockX_multiY .values ()], dim = - 1 )
@@ -406,6 +405,20 @@ def test_construct_inputs_qEI(self):
406
405
best_f = best_f_expected ,
407
406
)
408
407
self .assertEqual (kwargs ["best_f" ], best_f_expected )
408
+ # test passing constraints
409
+ outcome_constraints = (torch .tensor ([[0.0 , 1.0 ]]), torch .tensor ([[0.5 ]]))
410
+ constraints = get_outcome_constraint_transforms (
411
+ outcome_constraints = outcome_constraints
412
+ )
413
+ kwargs = c (
414
+ model = mock_model ,
415
+ training_data = self .blockX_multiY ,
416
+ objective = objective ,
417
+ X_pending = X_pending ,
418
+ best_f = best_f_expected ,
419
+ constraints = constraints ,
420
+ )
421
+ self .assertIs (kwargs ["constraints" ], constraints )
409
422
410
423
# testing qLogEI input constructor
411
424
log_constructor = get_acqf_input_constructor (qLogExpectedImprovement )
@@ -415,6 +428,7 @@ def test_construct_inputs_qEI(self):
415
428
objective = objective ,
416
429
X_pending = X_pending ,
417
430
best_f = best_f_expected ,
431
+ constraints = constraints ,
418
432
)
419
433
# includes strict superset of kwargs tested above
420
434
self .assertTrue (kwargs .items () <= log_kwargs .items ())
@@ -423,6 +437,7 @@ def test_construct_inputs_qEI(self):
423
437
self .assertEqual (log_kwargs ["tau_max" ], TAU_MAX )
424
438
self .assertTrue ("tau_relu" in log_kwargs )
425
439
self .assertEqual (log_kwargs ["tau_relu" ], TAU_RELU )
440
+ self .assertIs (log_kwargs ["constraints" ], constraints )
426
441
427
442
def test_construct_inputs_qNEI (self ):
428
443
c = get_acqf_input_constructor (qNoisyExpectedImprovement )
@@ -441,29 +456,36 @@ def test_construct_inputs_qNEI(self):
441
456
with self .assertRaisesRegex (ValueError , "Field `X` must be shared" ):
442
457
c (model = mock_model , training_data = self .multiX_multiY )
443
458
X_baseline = torch .rand (2 , 2 )
459
+ outcome_constraints = (torch .tensor ([[0.0 , 1.0 ]]), torch .tensor ([[0.5 ]]))
460
+ constraints = get_outcome_constraint_transforms (
461
+ outcome_constraints = outcome_constraints
462
+ )
444
463
kwargs = c (
445
464
model = mock_model ,
446
465
training_data = self .blockX_blockY ,
447
466
X_baseline = X_baseline ,
448
467
prune_baseline = False ,
468
+ constraints = constraints ,
449
469
)
450
470
self .assertEqual (kwargs ["model" ], mock_model )
451
471
self .assertIsNone (kwargs ["objective" ])
452
472
self .assertIsNone (kwargs ["X_pending" ])
453
473
self .assertIsNone (kwargs ["sampler" ])
454
474
self .assertFalse (kwargs ["prune_baseline" ])
455
475
self .assertTrue (torch .equal (kwargs ["X_baseline" ], X_baseline ))
456
- self .assertIsNone (kwargs ["constraints" ])
457
476
self .assertIsInstance (kwargs ["eta" ], float )
458
477
self .assertTrue (kwargs ["eta" ] < 1 )
478
+ self .assertIs (kwargs ["constraints" ], constraints )
459
479
460
480
# testing qLogNEI input constructor
461
481
log_constructor = get_acqf_input_constructor (qLogNoisyExpectedImprovement )
482
+
462
483
log_kwargs = log_constructor (
463
484
model = mock_model ,
464
485
training_data = self .blockX_blockY ,
465
486
X_baseline = X_baseline ,
466
487
prune_baseline = False ,
488
+ constraints = constraints ,
467
489
)
468
490
# includes strict superset of kwargs tested above
469
491
self .assertTrue (kwargs .items () <= log_kwargs .items ())
@@ -472,6 +494,7 @@ def test_construct_inputs_qNEI(self):
472
494
self .assertEqual (log_kwargs ["tau_max" ], TAU_MAX )
473
495
self .assertTrue ("tau_relu" in log_kwargs )
474
496
self .assertEqual (log_kwargs ["tau_relu" ], TAU_RELU )
497
+ self .assertIs (log_kwargs ["constraints" ], constraints )
475
498
476
499
def test_construct_inputs_qPI (self ):
477
500
c = get_acqf_input_constructor (qProbabilityOfImprovement )
@@ -499,23 +522,28 @@ def test_construct_inputs_qPI(self):
499
522
self .assertTrue (torch .equal (kwargs ["X_pending" ], X_pending ))
500
523
self .assertIsNone (kwargs ["sampler" ])
501
524
self .assertEqual (kwargs ["tau" ], 1e-2 )
502
- self .assertIsNone (kwargs ["constraints" ])
503
525
self .assertIsInstance (kwargs ["eta" ], float )
504
526
self .assertTrue (kwargs ["eta" ] < 1 )
505
527
multi_Y = torch .cat ([d .Y () for d in self .blockX_multiY .values ()], dim = - 1 )
506
528
best_f_expected = objective (multi_Y ).max ()
507
529
self .assertEqual (kwargs ["best_f" ], best_f_expected )
508
530
# Check explicitly specifying `best_f`.
509
531
best_f_expected = best_f_expected - 1 # Random value.
532
+ outcome_constraints = (torch .tensor ([[0.0 , 1.0 ]]), torch .tensor ([[0.5 ]]))
533
+ constraints = get_outcome_constraint_transforms (
534
+ outcome_constraints = outcome_constraints
535
+ )
510
536
kwargs = c (
511
537
model = mock_model ,
512
538
training_data = self .blockX_multiY ,
513
539
objective = objective ,
514
540
X_pending = X_pending ,
515
541
tau = 1e-2 ,
516
542
best_f = best_f_expected ,
543
+ constraints = constraints ,
517
544
)
518
545
self .assertEqual (kwargs ["best_f" ], best_f_expected )
546
+ self .assertIs (kwargs ["constraints" ], constraints )
519
547
520
548
def test_construct_inputs_qUCB (self ):
521
549
c = get_acqf_input_constructor (qUpperConfidenceBound )
@@ -564,7 +592,7 @@ def test_construct_inputs_EHVI(self):
564
592
model = mock_model ,
565
593
training_data = self .blockX_blockY ,
566
594
objective_thresholds = objective_thresholds ,
567
- outcome_constraints = mock .Mock (),
595
+ constraints = mock .Mock (),
568
596
)
569
597
570
598
# test with Y_pmean supplied explicitly
@@ -702,13 +730,16 @@ def test_construct_inputs_qEHVI(self):
702
730
weights = torch .rand (2 )
703
731
obj = WeightedMCMultiOutputObjective (weights = weights )
704
732
outcome_constraints = (torch .tensor ([[0.0 , 1.0 ]]), torch .tensor ([[0.5 ]]))
733
+ constraints = get_outcome_constraint_transforms (
734
+ outcome_constraints = outcome_constraints
735
+ )
705
736
X_pending = torch .rand (1 , 2 )
706
737
kwargs = c (
707
738
model = mm ,
708
739
training_data = self .blockX_blockY ,
709
740
objective_thresholds = objective_thresholds ,
710
741
objective = obj ,
711
- outcome_constraints = outcome_constraints ,
742
+ constraints = constraints ,
712
743
X_pending = X_pending ,
713
744
alpha = 0.05 ,
714
745
eta = 1e-2 ,
@@ -723,11 +754,7 @@ def test_construct_inputs_qEHVI(self):
723
754
Y_expected = mean [:1 ] * weights
724
755
self .assertTrue (torch .equal (partitioning ._neg_Y , - Y_expected ))
725
756
self .assertTrue (torch .equal (kwargs ["X_pending" ], X_pending ))
726
- cons_tfs = kwargs ["constraints" ]
727
- self .assertEqual (len (cons_tfs ), 1 )
728
- cons_eval = cons_tfs [0 ](mean )
729
- cons_eval_expected = torch .tensor ([- 0.25 , 0.5 ])
730
- self .assertTrue (torch .equal (cons_eval , cons_eval_expected ))
757
+ self .assertIs (kwargs ["constraints" ], constraints )
731
758
self .assertEqual (kwargs ["eta" ], 1e-2 )
732
759
733
760
# Test check for block designs
@@ -737,7 +764,7 @@ def test_construct_inputs_qEHVI(self):
737
764
training_data = self .multiX_multiY ,
738
765
objective_thresholds = objective_thresholds ,
739
766
objective = obj ,
740
- outcome_constraints = outcome_constraints ,
767
+ constraints = constraints ,
741
768
X_pending = X_pending ,
742
769
alpha = 0.05 ,
743
770
eta = 1e-2 ,
@@ -798,6 +825,9 @@ def test_construct_inputs_qNEHVI(self):
798
825
X_baseline = torch .rand (2 , 2 )
799
826
sampler = IIDNormalSampler (sample_shape = torch .Size ([4 ]))
800
827
outcome_constraints = (torch .tensor ([[0.0 , 1.0 ]]), torch .tensor ([[0.5 ]]))
828
+ constraints = get_outcome_constraint_transforms (
829
+ outcome_constraints = outcome_constraints
830
+ )
801
831
X_pending = torch .rand (1 , 2 )
802
832
kwargs = c (
803
833
model = mock_model ,
@@ -806,7 +836,7 @@ def test_construct_inputs_qNEHVI(self):
806
836
objective = objective ,
807
837
X_baseline = X_baseline ,
808
838
sampler = sampler ,
809
- outcome_constraints = outcome_constraints ,
839
+ constraints = constraints ,
810
840
X_pending = X_pending ,
811
841
eta = 1e-2 ,
812
842
prune_baseline = True ,
@@ -823,11 +853,7 @@ def test_construct_inputs_qNEHVI(self):
823
853
self .assertIsInstance (sampler_ , IIDNormalSampler )
824
854
self .assertEqual (sampler_ .sample_shape , torch .Size ([4 ]))
825
855
self .assertEqual (kwargs ["objective" ], objective )
826
- cons_tfs_expected = get_outcome_constraint_transforms (outcome_constraints )
827
- cons_tfs = kwargs ["constraints" ]
828
- self .assertEqual (len (cons_tfs ), 1 )
829
- test_Y = torch .rand (1 , 2 )
830
- self .assertTrue (torch .equal (cons_tfs [0 ](test_Y ), cons_tfs_expected [0 ](test_Y )))
856
+ self .assertIs (kwargs ["constraints" ], constraints )
831
857
self .assertTrue (torch .equal (kwargs ["X_pending" ], X_pending ))
832
858
self .assertEqual (kwargs ["eta" ], 1e-2 )
833
859
self .assertTrue (kwargs ["prune_baseline" ])
@@ -844,7 +870,7 @@ def test_construct_inputs_qNEHVI(self):
844
870
training_data = self .blockX_blockY ,
845
871
objective_thresholds = objective_thresholds ,
846
872
objective = MultiOutputExpectation (n_w = 3 ),
847
- outcome_constraints = outcome_constraints ,
873
+ constraints = constraints ,
848
874
)
849
875
for use_preprocessing in (True , False ):
850
876
obj = MultiOutputExpectation (
0 commit comments