Commit 52529e1

sdaulton authored and facebook-github-bot committed

fix constraint handling in single objective MBM (#1973)

Summary:
X-link: facebook/Ax#1771
Pull Request resolved: #1973

Currently, constraints are not used in single-objective acquisition functions (AFs) in MBM due to a name mismatch between `outcome_constraints` and `constraints`.

Reviewed By: SebastianAment
Differential Revision: D48176978
fbshipit-source-id: 9495708002c11a874bb6b8c06327f0f4643039df

1 parent 3506538 · commit 52529e1
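For context, a minimal end-to-end sketch of the call path this commit fixes (not part of the commit; the model/dataset setup is made up, and `SupervisedDataset`'s exact signature varies across BoTorch versions). Before the fix, `constraints` passed to a single-objective input constructor were silently dropped by the `constraint`/`outcome_constraints` name mismatch; after it, they reach the acquisition function:

```python
# Hypothetical sketch: constraints now flow through the input constructor.
import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.acquisition.objective import LinearMCObjective
from botorch.models import SingleTaskGP
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.datasets import SupervisedDataset

X = torch.rand(8, 2, dtype=torch.double)
Y = torch.rand(8, 2, dtype=torch.double)  # outcome 0: objective, outcome 1: constraint
model = SingleTaskGP(X, Y)

# Feasibility means outcome 1 <= 0.5, i.e. A @ y <= b with A = [[0, 1]], b = [[0.5]].
constraints = get_outcome_constraint_transforms(
    (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
)

constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement)
kwargs = constructor(
    model=model,
    training_data=SupervisedDataset(X=X, Y=Y),  # signature varies by version
    objective=LinearMCObjective(weights=torch.tensor([1.0, 0.0], dtype=torch.double)),
    X_baseline=X,
    constraints=constraints,
)
assert kwargs["constraints"] is constraints  # previously this ended up as None
acqf = qLogNoisyExpectedImprovement(**kwargs)
```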

File tree

4 files changed: +178 −97 lines changed

botorch/acquisition/input_constructors.py

Lines changed: 10 additions & 17 deletions
```diff
@@ -97,7 +97,6 @@
 from botorch.optim.optimize import optimize_acqf
 from botorch.sampling.base import MCSampler
 from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
-from botorch.utils.constraints import get_outcome_constraint_transforms
 from botorch.utils.containers import BotorchContainer
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.multi_objective.box_decompositions.non_dominated import (
@@ -718,7 +717,7 @@ def construct_inputs_qLogNEI(
             X_baseline=X_baseline,
             prune_baseline=prune_baseline,
             cache_root=cache_root,
-            constraint=constraints,
+            constraints=constraints,
             eta=eta,
         ),
         "fat": fat,
@@ -853,11 +852,12 @@ def construct_inputs_EHVI(
     training_data: MaybeDict[SupervisedDataset],
     objective_thresholds: Tensor,
     objective: Optional[AnalyticMultiOutputObjective] = None,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for `ExpectedHypervolumeImprovement` constructor."""
     num_objectives = objective_thresholds.shape[0]
-    if kwargs.get("outcome_constraints") is not None:
+    if constraints is not None:
         raise NotImplementedError("EHVI does not yet support outcome constraints.")
 
     X = _get_dataset_field(
@@ -914,6 +914,7 @@ def construct_inputs_qEHVI(
     training_data: MaybeDict[SupervisedDataset],
     objective_thresholds: Tensor,
     objective: Optional[MCMultiOutputObjective] = None,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for `qExpectedHypervolumeImprovement` constructor."""
@@ -928,15 +929,10 @@ def construct_inputs_qEHVI(
     # compute posterior mean (for ref point computation ref pareto frontier)
     with torch.no_grad():
         Y_pmean = model.posterior(X).mean
-
-    outcome_constraints = kwargs.pop("outcome_constraints", None)
     # For HV-based acquisition functions we pass the constraint transform directly
-    if outcome_constraints is None:
-        cons_tfs = None
-    else:
-        cons_tfs = get_outcome_constraint_transforms(outcome_constraints)
+    if constraints is not None:
         # Adjust `Y_pmean` to contrain feasible points only.
-        feas = torch.stack([c(Y_pmean) <= 0 for c in cons_tfs], dim=-1).all(dim=-1)
+        feas = torch.stack([c(Y_pmean) <= 0 for c in constraints], dim=-1).all(dim=-1)
         Y_pmean = Y_pmean[feas]
 
     if objective is None:
@@ -962,7 +958,7 @@ def construct_inputs_qEHVI(
     add_qehvi_kwargs = {
         "sampler": sampler,
         "X_pending": kwargs.get("X_pending"),
-        "constraints": cons_tfs,
+        "constraints": constraints,
         "eta": kwargs.get("eta", 1e-3),
     }
     return {**ehvi_kwargs, **add_qehvi_kwargs}
@@ -975,6 +971,7 @@ def construct_inputs_qNEHVI(
     objective_thresholds: Tensor,
     objective: Optional[MCMultiOutputObjective] = None,
     X_baseline: Optional[Tensor] = None,
+    constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
     **kwargs: Any,
 ) -> Dict[str, Any]:
     r"""Construct kwargs for `qNoisyExpectedHypervolumeImprovement` constructor."""
@@ -991,16 +988,12 @@ def construct_inputs_qNEHVI(
     if objective is None:
         objective = IdentityMCMultiOutputObjective()
 
-    outcome_constraints = kwargs.pop("outcome_constraints", None)
-    if outcome_constraints is None:
-        cons_tfs = None
-    else:
+    if constraints is not None:
         if isinstance(objective, RiskMeasureMCObjective):
             raise UnsupportedError(
                 "Outcome constraints are not supported with risk measures. "
                 "Use a feasibility-weighted risk measure instead."
             )
-        cons_tfs = get_outcome_constraint_transforms(outcome_constraints)
 
     sampler = kwargs.get("sampler")
     if sampler is None and isinstance(model, GPyTorchModel):
@@ -1021,7 +1014,7 @@ def construct_inputs_qNEHVI(
         "X_baseline": X_baseline,
         "sampler": sampler,
         "objective": objective,
-        "constraints": cons_tfs,
+        "constraints": constraints,
         "X_pending": kwargs.get("X_pending"),
         "eta": kwargs.get("eta", 1e-3),
         "prune_baseline": kwargs.get("prune_baseline", True),
```
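The convention change above: callers now convert tensor-form `outcome_constraints = (A, b)` (meaning `A @ y <= b`) into a list of callables themselves, and the constructors forward that list untouched. A small sketch of what `get_outcome_constraint_transforms` produces (the `Y` values here are made up; each callable returns `Y @ A[i] - b[i]`, with `<= 0` meaning feasible):

```python
import torch
from botorch.utils.constraints import get_outcome_constraint_transforms

A = torch.tensor([[0.0, 1.0]])  # one constraint on a 2-outcome model
b = torch.tensor([[0.5]])       # together: y[..., 1] <= 0.5
constraints = get_outcome_constraint_transforms((A, b))

Y = torch.tensor([[0.2, 0.25], [0.7, 1.0]])
vals = constraints[0](Y)  # Y @ A[0] - b[0], i.e. y[..., 1] - 0.5
print(vals)               # tensor([-0.2500,  0.5000]); <= 0 means feasible

# The feasibility mask used by the qEHVI constructor above:
feas = torch.stack([c(Y) <= 0 for c in constraints], dim=-1).all(dim=-1)
print(feas)               # tensor([ True, False])
```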

botorch/acquisition/utils.py

Lines changed: 10 additions & 7 deletions
```diff
@@ -306,13 +306,12 @@ def compute_best_feasible_objective(
     is_feasible = compute_feasibility_indicator(
         constraints=constraints, samples=samples
     )  # sample_shape x batch_shape x q
-    if is_feasible.any():
-        obj = torch.where(is_feasible, obj, -torch.inf)
-        with torch.no_grad():
-            return obj.amax(dim=-1, keepdim=True)
+
+    if is_feasible.any(dim=-1).all():
+        infeasible_value = -torch.inf
 
     elif infeasible_obj is not None:
-        return infeasible_obj.expand(*obj.shape[:-1], 1)
+        infeasible_value = infeasible_obj.item()
 
     else:
         if model is None:
@@ -323,12 +322,16 @@ def compute_best_feasible_objective(
             raise ValueError(
                 "Must specify `X_baseline` when no feasible observation exists."
             )
-        return _estimate_objective_lower_bound(
+        infeasible_value = _estimate_objective_lower_bound(
             model=model,
             objective=objective,
             posterior_transform=posterior_transform,
             X=X_baseline,
-        ).expand(*obj.shape[:-1], 1)
+        ).item()
+
+    obj = torch.where(is_feasible, obj, infeasible_value)
+    with torch.no_grad():
+        return obj.amax(dim=-1, keepdim=True)
 
 
 def _estimate_objective_lower_bound(
```
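This refactor collapses three early-return branches into one path: each branch now only picks an `infeasible_value`, and a single masked max runs at the end. A tiny numeric sketch of that final pattern (values made up; `-torch.inf` is safe here because `is_feasible.any(dim=-1).all()` guarantees every batch element has at least one feasible point):

```python
import torch

obj = torch.tensor([[1.0, 3.0, 2.0]])              # batch_shape x q
is_feasible = torch.tensor([[True, False, True]])  # same shape
infeasible_value = -torch.inf                      # branch for "all batches feasible"

masked = torch.where(is_feasible, obj, infeasible_value)  # tensor([[1., -inf, 2.]])
best = masked.amax(dim=-1, keepdim=True)
print(best)  # tensor([[2.]]) -- the best *feasible* value, not the overall max 3.0
```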

test/acquisition/test_input_constructors.py

Lines changed: 44 additions & 18 deletions
```diff
@@ -390,7 +390,6 @@ def test_construct_inputs_qEI(self):
         self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights))
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertIsNone(kwargs["sampler"])
-        self.assertIsNone(kwargs["constraints"])
         self.assertIsInstance(kwargs["eta"], float)
         self.assertTrue(kwargs["eta"] < 1)
         multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
@@ -406,6 +405,20 @@ def test_construct_inputs_qEI(self):
             best_f=best_f_expected,
         )
         self.assertEqual(kwargs["best_f"], best_f_expected)
+        # test passing constraints
+        outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
+        kwargs = c(
+            model=mock_model,
+            training_data=self.blockX_multiY,
+            objective=objective,
+            X_pending=X_pending,
+            best_f=best_f_expected,
+            constraints=constraints,
+        )
+        self.assertIs(kwargs["constraints"], constraints)
 
         # testing qLogEI input constructor
         log_constructor = get_acqf_input_constructor(qLogExpectedImprovement)
@@ -415,6 +428,7 @@ def test_construct_inputs_qEI(self):
             objective=objective,
             X_pending=X_pending,
             best_f=best_f_expected,
+            constraints=constraints,
         )
         # includes strict superset of kwargs tested above
         self.assertTrue(kwargs.items() <= log_kwargs.items())
@@ -423,6 +437,7 @@ def test_construct_inputs_qEI(self):
         self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
         self.assertTrue("tau_relu" in log_kwargs)
         self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
+        self.assertIs(log_kwargs["constraints"], constraints)
 
     def test_construct_inputs_qNEI(self):
         c = get_acqf_input_constructor(qNoisyExpectedImprovement)
@@ -441,29 +456,36 @@ def test_construct_inputs_qNEI(self):
         with self.assertRaisesRegex(ValueError, "Field `X` must be shared"):
             c(model=mock_model, training_data=self.multiX_multiY)
         X_baseline = torch.rand(2, 2)
+        outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         kwargs = c(
             model=mock_model,
             training_data=self.blockX_blockY,
             X_baseline=X_baseline,
             prune_baseline=False,
+            constraints=constraints,
         )
         self.assertEqual(kwargs["model"], mock_model)
         self.assertIsNone(kwargs["objective"])
         self.assertIsNone(kwargs["X_pending"])
         self.assertIsNone(kwargs["sampler"])
         self.assertFalse(kwargs["prune_baseline"])
         self.assertTrue(torch.equal(kwargs["X_baseline"], X_baseline))
-        self.assertIsNone(kwargs["constraints"])
         self.assertIsInstance(kwargs["eta"], float)
         self.assertTrue(kwargs["eta"] < 1)
+        self.assertIs(kwargs["constraints"], constraints)
 
         # testing qLogNEI input constructor
         log_constructor = get_acqf_input_constructor(qLogNoisyExpectedImprovement)
+
         log_kwargs = log_constructor(
             model=mock_model,
             training_data=self.blockX_blockY,
             X_baseline=X_baseline,
             prune_baseline=False,
+            constraints=constraints,
         )
         # includes strict superset of kwargs tested above
         self.assertTrue(kwargs.items() <= log_kwargs.items())
@@ -472,6 +494,7 @@ def test_construct_inputs_qNEI(self):
         self.assertEqual(log_kwargs["tau_max"], TAU_MAX)
         self.assertTrue("tau_relu" in log_kwargs)
         self.assertEqual(log_kwargs["tau_relu"], TAU_RELU)
+        self.assertIs(log_kwargs["constraints"], constraints)
 
     def test_construct_inputs_qPI(self):
         c = get_acqf_input_constructor(qProbabilityOfImprovement)
@@ -499,23 +522,28 @@ def test_construct_inputs_qPI(self):
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertIsNone(kwargs["sampler"])
         self.assertEqual(kwargs["tau"], 1e-2)
-        self.assertIsNone(kwargs["constraints"])
         self.assertIsInstance(kwargs["eta"], float)
         self.assertTrue(kwargs["eta"] < 1)
         multi_Y = torch.cat([d.Y() for d in self.blockX_multiY.values()], dim=-1)
         best_f_expected = objective(multi_Y).max()
         self.assertEqual(kwargs["best_f"], best_f_expected)
         # Check explicitly specifying `best_f`.
         best_f_expected = best_f_expected - 1  # Random value.
+        outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         kwargs = c(
             model=mock_model,
             training_data=self.blockX_multiY,
             objective=objective,
             X_pending=X_pending,
             tau=1e-2,
             best_f=best_f_expected,
+            constraints=constraints,
         )
         self.assertEqual(kwargs["best_f"], best_f_expected)
+        self.assertIs(kwargs["constraints"], constraints)
 
     def test_construct_inputs_qUCB(self):
         c = get_acqf_input_constructor(qUpperConfidenceBound)
@@ -564,7 +592,7 @@ def test_construct_inputs_EHVI(self):
             model=mock_model,
             training_data=self.blockX_blockY,
             objective_thresholds=objective_thresholds,
-            outcome_constraints=mock.Mock(),
+            constraints=mock.Mock(),
         )
 
         # test with Y_pmean supplied explicitly
@@ -702,13 +730,16 @@ def test_construct_inputs_qEHVI(self):
         weights = torch.rand(2)
         obj = WeightedMCMultiOutputObjective(weights=weights)
         outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         X_pending = torch.rand(1, 2)
         kwargs = c(
             model=mm,
             training_data=self.blockX_blockY,
             objective_thresholds=objective_thresholds,
             objective=obj,
-            outcome_constraints=outcome_constraints,
+            constraints=constraints,
             X_pending=X_pending,
             alpha=0.05,
             eta=1e-2,
@@ -723,11 +754,7 @@ def test_construct_inputs_qEHVI(self):
         Y_expected = mean[:1] * weights
         self.assertTrue(torch.equal(partitioning._neg_Y, -Y_expected))
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
-        cons_tfs = kwargs["constraints"]
-        self.assertEqual(len(cons_tfs), 1)
-        cons_eval = cons_tfs[0](mean)
-        cons_eval_expected = torch.tensor([-0.25, 0.5])
-        self.assertTrue(torch.equal(cons_eval, cons_eval_expected))
+        self.assertIs(kwargs["constraints"], constraints)
         self.assertEqual(kwargs["eta"], 1e-2)
 
         # Test check for block designs
@@ -737,7 +764,7 @@ def test_construct_inputs_qEHVI(self):
             training_data=self.multiX_multiY,
             objective_thresholds=objective_thresholds,
             objective=obj,
-            outcome_constraints=outcome_constraints,
+            constraints=constraints,
             X_pending=X_pending,
             alpha=0.05,
             eta=1e-2,
@@ -798,6 +825,9 @@ def test_construct_inputs_qNEHVI(self):
         X_baseline = torch.rand(2, 2)
         sampler = IIDNormalSampler(sample_shape=torch.Size([4]))
         outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
+        constraints = get_outcome_constraint_transforms(
+            outcome_constraints=outcome_constraints
+        )
         X_pending = torch.rand(1, 2)
         kwargs = c(
             model=mock_model,
@@ -806,7 +836,7 @@ def test_construct_inputs_qNEHVI(self):
             objective=objective,
             X_baseline=X_baseline,
             sampler=sampler,
-            outcome_constraints=outcome_constraints,
+            constraints=constraints,
             X_pending=X_pending,
             eta=1e-2,
             prune_baseline=True,
@@ -823,11 +853,7 @@ def test_construct_inputs_qNEHVI(self):
         self.assertIsInstance(sampler_, IIDNormalSampler)
         self.assertEqual(sampler_.sample_shape, torch.Size([4]))
         self.assertEqual(kwargs["objective"], objective)
-        cons_tfs_expected = get_outcome_constraint_transforms(outcome_constraints)
-        cons_tfs = kwargs["constraints"]
-        self.assertEqual(len(cons_tfs), 1)
-        test_Y = torch.rand(1, 2)
-        self.assertTrue(torch.equal(cons_tfs[0](test_Y), cons_tfs_expected[0](test_Y)))
+        self.assertIs(kwargs["constraints"], constraints)
         self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
         self.assertEqual(kwargs["eta"], 1e-2)
         self.assertTrue(kwargs["prune_baseline"])
@@ -844,7 +870,7 @@ def test_construct_inputs_qNEHVI(self):
             training_data=self.blockX_blockY,
             objective_thresholds=objective_thresholds,
             objective=MultiOutputExpectation(n_w=3),
-            outcome_constraints=outcome_constraints,
+            constraints=constraints,
         )
         for use_preprocessing in (True, False):
             obj = MultiOutputExpectation(
```
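Since the constructors now forward the `constraints` list untouched, the value-based assertions above could be replaced by identity checks (`assertIs`). The EHVI test also changes: analytic EHVI still rejects constraints, but the check now keys off the `constraints` kwarg rather than `outcome_constraints` pulled out of `**kwargs`. A standalone sketch of that behavior (hypothetical call; the constraints check fires before `model` or `training_data` are used):

```python
import torch
from botorch.acquisition.input_constructors import get_acqf_input_constructor
from botorch.acquisition.multi_objective.analytic import ExpectedHypervolumeImprovement

c = get_acqf_input_constructor(ExpectedHypervolumeImprovement)
try:
    c(
        model=None,        # never used: the constraints check fires first
        training_data=None,
        objective_thresholds=torch.zeros(2),
        constraints=[lambda Y: Y[..., 1] - 0.5],  # any non-None value triggers the error
    )
except NotImplementedError as e:
    print(e)  # EHVI does not yet support outcome constraints.
```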
