
Commit 5e3677f

dme65 authored and facebook-github-bot committed
Fix SCBO tutorial (#2034)
Summary:
This fixes several constraint handling bugs in the SCBO tutorial as reported in #2031, #2032, and #2033. In addition, it fixes a bug and cleans up some of the logic in `ConstrainedMaxPosteriorSampling`. The previous test problem was too easy, so some aspects of the constraint handling were never actually exercised. I switched to the same version of the Ackley problem as considered in https://arxiv.org/pdf/2002.08526.pdf and I'm observing comparable performance.

Pull Request resolved: #2034

Reviewed By: Balandat

Differential Revision: D49958875

Pulled By: dme65

fbshipit-source-id: 72acb2c7fbbc598bbdc0d46529136e3173df3f3a
1 parent dcb2ba4 commit 5e3677f
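
For context, the Ackley setup referenced in the summary looks roughly like the sketch below. This is a minimal sketch, not part of the diff: the dimension (10), the bounds ([-5, 10]^10), and the two constraints c1(x) = sum(x) <= 0 and c2(x) = ||x||_2 - 5 <= 0 are assumptions based on the SCBO tutorial and paper, stated here for illustration only.

import torch
from botorch.test_functions import Ackley

# Assumed problem setup: maximize 10D Ackley (negate=True) on [-5, 10]^10
# subject to two black-box constraints of the form c(x) <= 0.
fun = Ackley(dim=10, negate=True)
fun.bounds[0, :].fill_(-5)
fun.bounds[1, :].fill_(10)

def c1(x: torch.Tensor) -> torch.Tensor:
    # Feasible iff the coordinates sum to at most zero.
    return x.sum(dim=-1)

def c2(x: torch.Tensor) -> torch.Tensor:
    # Feasible iff x lies within a ball of radius 5.
    return torch.linalg.norm(x, dim=-1) - 5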

File tree

3 files changed: +709, -788 lines


botorch/generation/sampling.py

Lines changed: 78 additions & 61 deletions
@@ -228,21 +228,19 @@ def forward(self, X: Tensor, num_samples: int = 1) -> Tensor:
 
 
 class ConstrainedMaxPosteriorSampling(MaxPosteriorSampling):
-    r"""Sample from a set of points according to
-    their max posterior value,
-    which also likely meet a set of constraints
-    c1(x) <= 0, c2(x) <= 0, ..., cm(x) <= 0
-    c1, c2, ..., cm are black-box constraint functions
-    Each constraint function is modeled by a seperate
-    surrogate GP constraint model
-    We sample points for which the posterior value
-    for each constraint model <= 0,
-    as described in https://doi.org/10.48550/arxiv.2002.08526
+    r"""Constrained max posterior sampling.
+
+    Posterior sampling where we try to maximize an objective function while
+    simultaneously satisfying a set of constraints c1(x) <= 0, c2(x) <= 0,
+    ..., cm(x) <= 0 where c1, c2, ..., cm are black-box constraint functions.
+    Each constraint function is modeled by a separate GP model. We follow the
+    procedure as described in https://doi.org/10.48550/arxiv.2002.08526.
 
     Example:
-        >>> CMPS = ConstrainedMaxPosteriorSampling(model,
-                    constraint_model=ModelListGP(cmodel1, cmodel2,
-                    ..., cmodelm) # models w/ feature dim d=3
+        >>> CMPS = ConstrainedMaxPosteriorSampling(
+        >>>     model,
+        >>>     constraint_model=ModelListGP(cmodel1, cmodel2),
+        >>> )
         >>> X = torch.rand(2, 100, 3)
         >>> sampled_X = CMPS(X, num_samples=5)
     """
@@ -254,82 +252,101 @@ def __init__(
         objective: Optional[MCAcquisitionObjective] = None,
         posterior_transform: Optional[PosteriorTransform] = None,
         replacement: bool = True,
-        minimize_constraints_only: bool = False,
     ) -> None:
         r"""Constructor for the SamplingStrategy base class.
 
         Args:
             model: A fitted model.
-            objective: The MCAcquisitionObjective under
-                which the samples are evaluated.
+            objective: The MCAcquisitionObjective under which the samples are evaluated.
                 Defaults to `IdentityMCObjective()`.
-            posterior_transform: An optional PosteriorTransform.
+            posterior_transform: An optional PosteriorTransform for the objective
+                function (corresponding to `model`).
             replacement: If True, sample with replacement.
-            constraint_model: either a ModelListGP where each submodel
-                is a GP model for one constraint function,
-                or a MultiTaskGP model where each task is one
-                constraint function
-                All constraints are of the form c(x) <= 0.
-                In the case when the constraint model predicts
-                that all candidates violate constraints,
-                we pick the candidates with minimum violation.
-            minimize_constraints_only: False by default, if true,
-                we will automatically return the candidates
-                with minimum posterior constraint values,
-                (minimum predicted c(x) summed over all constraints)
-                reguardless of predicted objective values.
+            constraint_model: either a ModelListGP where each submodel is a GP model for
+                one constraint function, or a MultiTaskGP model where each task is one
+                constraint function. All constraints are of the form c(x) <= 0. In the
+                case when the constraint model predicts that all candidates
+                violate constraints, we pick the candidates with minimum violation.
         """
+        if objective is not None:
+            raise NotImplementedError(
+                "`objective` is not supported for `ConstrainedMaxPosteriorSampling`."
+            )
+
         super().__init__(
             model=model,
             objective=objective,
             posterior_transform=posterior_transform,
             replacement=replacement,
         )
         self.constraint_model = constraint_model
-        self.minimize_constraints_only = minimize_constraints_only
+
+    def _convert_samples_to_scores(self, Y_samples, C_samples) -> Tensor:
+        r"""Convert the objective and constraint samples into a score.
+
+        The logic is as follows:
+            - If a realization has at least one feasible candidate we use the objective
+                value as the score and set all infeasible candidates to -inf.
+            - If a realization doesn't have a feasible candidate we set the score to
+                the negative total violation of the constraints to incentivize choosing
+                the candidate with the smallest constraint violation.
+
+        Args:
+            Y_samples: A `num_samples x batch_shape x num_cand x 1`-dim Tensor of
+                samples from the objective function.
+            C_samples: A `num_samples x batch_shape x num_cand x num_constraints`-dim
+                Tensor of samples from the constraints.
+
+        Returns:
+            A `num_samples x batch_shape x num_cand x 1`-dim Tensor of scores.
+        """
+        is_feasible = (C_samples <= 0).all(
+            dim=-1
+        )  # num_samples x batch_shape x num_cand
+        has_feasible_candidate = is_feasible.any(dim=-1)
+
+        scores = Y_samples.clone()
+        scores[~is_feasible] = -float("inf")
+        if not has_feasible_candidate.all():
+            # Use negative total violation for samples where no candidate is feasible
+            total_violation = (
+                C_samples[~has_feasible_candidate]
+                .clamp(min=0)
+                .sum(dim=-1, keepdim=True)
+            )
+            scores[~has_feasible_candidate] = -total_violation
+        return scores
 
     def forward(
         self, X: Tensor, num_samples: int = 1, observation_noise: bool = False
     ) -> Tensor:
         r"""Sample from the model posterior.
 
         Args:
-            X: A `batch_shape x N x d`-dim Tensor
-                from which to sample (in the `N`
-                dimension) according to the maximum
-                posterior value under the objective.
+            X: A `batch_shape x N x d`-dim Tensor from which to sample (in the `N`
+                dimension) according to the maximum posterior value under the objective.
             num_samples: The number of samples to draw.
             observation_noise: If True, sample with observation noise.
 
         Returns:
-            A `batch_shape x num_samples x d`-dim
-                Tensor of samples from `X`, where
-                `X[..., i, :]` is the `i`-th sample.
+            A `batch_shape x num_samples x d`-dim Tensor of samples from `X`, where
+                `X[..., i, :]` is the `i`-th sample.
         """
-        posterior = self.model.posterior(X, observation_noise=observation_noise)
-        samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
+        posterior = self.model.posterior(
+            X=X,
+            observation_noise=observation_noise,
+            # Note: `posterior_transform` is only used for the objective
+            posterior_transform=self.posterior_transform,
+        )
+        Y_samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
 
         c_posterior = self.constraint_model.posterior(
-            X, observation_noise=observation_noise
-        )
-        constraint_samples = c_posterior.rsample(sample_shape=torch.Size([num_samples]))
-        valid_samples = constraint_samples <= 0
-        if valid_samples.shape[-1] > 1:  # if more than one constraint
-            valid_samples = torch.all(valid_samples, dim=-1).unsqueeze(-1)
-        if (valid_samples.sum() == 0) or self.minimize_constraints_only:
-            # if none of the samples meet the constraints
-            # we pick the one that minimizes total violation
-            constraint_samples = constraint_samples.sum(dim=-1)
-            idcs = torch.argmin(constraint_samples, dim=-1)
-            if idcs.ndim > 1:
-                idcs = idcs.permute(*range(1, idcs.ndim), 0)
-            idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1))
-            Xe = X.expand(*constraint_samples.shape[1:], X.size(-1))
-            return torch.gather(Xe, -2, idcs)
-        # replace all violators with -infinty so it will never choose them
-        replacement_infs = -torch.inf * torch.ones(samples.shape).to(X.device).to(
-            X.dtype
+            X=X, observation_noise=observation_noise
         )
-        samples = torch.where(valid_samples, samples, replacement_infs)
+        C_samples = c_posterior.rsample(sample_shape=torch.Size([num_samples]))
 
-        return self.maximize_samples(X, samples, num_samples)
+        # Convert the objective and constraint samples into a scalar-valued "score"
+        scores = self._convert_samples_to_scores(
+            Y_samples=Y_samples, C_samples=C_samples
+        )
+        return self.maximize_samples(X=X, samples=scores, num_samples=num_samples)
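
To make the new scoring rule concrete, here is a self-contained sketch (not part of the diff; shapes and values are made up) that reproduces the `_convert_samples_to_scores` logic above on toy tensors:

import torch

# One posterior realization, batch of 1, 4 candidates, 2 constraints.
Y_samples = torch.tensor([[[[0.1], [0.9], [0.5], [2.0]]]])  # 1 x 1 x 4 x 1
C_samples = torch.tensor(
    [[[[-1.0, -1.0],  # feasible
       [-0.5, -0.1],  # feasible
       [0.3, -0.2],   # violates c1
       [0.2, 0.4]]]]  # violates both
)  # 1 x 1 x 4 x 2

is_feasible = (C_samples <= 0).all(dim=-1)        # 1 x 1 x 4
has_feasible_candidate = is_feasible.any(dim=-1)  # 1 x 1

scores = Y_samples.clone()
scores[~is_feasible] = -float("inf")
if not has_feasible_candidate.all():
    # Fall back to negative total violation when no candidate is feasible.
    total_violation = (
        C_samples[~has_feasible_candidate].clamp(min=0).sum(dim=-1, keepdim=True)
    )
    scores[~has_feasible_candidate] = -total_violation

# Candidate 3 has the best objective (2.0) but is infeasible, so the best
# feasible candidate 1 (objective 0.9) gets the highest score.
assert scores[0, 0].argmax().item() == 1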

test/generation/test_sampling.py

Lines changed: 71 additions & 47 deletions
@@ -154,24 +154,33 @@ class TestConstrainedMaxPosteriorSampling(BotorchTestCase):
     def test_init(self):
         mm = MockModel(MockPosterior(mean=None))
         cmms = MockModel(MockPosterior(mean=None))
-        MPS = ConstrainedMaxPosteriorSampling(mm, cmms)
-        self.assertEqual(MPS.model, mm)
-        self.assertTrue(MPS.replacement)
-        self.assertIsInstance(MPS.objective, IdentityMCObjective)
+        for replacement in (True, False):
+            MPS = ConstrainedMaxPosteriorSampling(mm, cmms, replacement=replacement)
+            self.assertEqual(MPS.model, mm)
+            self.assertEqual(MPS.replacement, replacement)
+            self.assertIsInstance(MPS.objective, IdentityMCObjective)
+
         obj = LinearMCObjective(torch.rand(2))
-        MPS = ConstrainedMaxPosteriorSampling(
-            mm, cmms, objective=obj, replacement=False
-        )
-        self.assertEqual(MPS.objective, obj)
-        self.assertFalse(MPS.replacement)
+        with self.assertRaisesRegex(
+            NotImplementedError, "`objective` is not supported"
+        ):
+            ConstrainedMaxPosteriorSampling(mm, cmms, objective=obj, replacement=False)
 
     def test_constrained_max_posterior_sampling(self):
         batch_shapes = (torch.Size(), torch.Size([3]), torch.Size([3, 2]))
         dtypes = (torch.float, torch.double)
-        for batch_shape, dtype, N, num_samples, d in itertools.product(
-            batch_shapes, dtypes, (5, 6), (1, 2), (1, 2)
+        for (
+            batch_shape,
+            dtype,
+            N,
+            num_samples,
+            d,
+            observation_noise,
+        ) in itertools.product(
+            batch_shapes, dtypes, (5, 6), (1, 2), (1, 2), (True, False)
         ):
             tkwargs = {"device": self.device, "dtype": dtype}
+            expected_shape = torch.Size(list(batch_shape) + [num_samples] + [d])
             # X is `batch_shape x N x d` = batch_shape x N x 1.
             X = torch.randn(*batch_shape, N, d, **tkwargs)
             # the event shape is `num_samples x batch_shape x N x m`
196205
cmms2 = ModelListGP(c_model1, c_model2)
197206
cmms3 = ModelListGP(c_model1, c_model2, c_model3)
198207
for cmms in [cmms1, cmms2, cmms3]:
199-
MPS = ConstrainedMaxPosteriorSampling(mm, cmms)
200-
s1 = MPS(X, num_samples=num_samples)
201-
# run again with minimize_constraints_only
202-
MPS = ConstrainedMaxPosteriorSampling(
203-
mm, cmms, minimize_constraints_only=True
208+
CPS = ConstrainedMaxPosteriorSampling(mm, cmms)
209+
s1 = CPS(
210+
X=X,
211+
num_samples=num_samples,
212+
observation_noise=observation_noise,
204213
)
205-
s2 = MPS(X, num_samples=num_samples)
206-
assert s1.shape == s2.shape
214+
self.assertEqual(s1.shape, expected_shape)
207215

208-
# ScalarizedPosteriorTransform w/ replacement
209-
with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
210-
mp = MockPosterior(None)
211-
with mock.patch.object(MockModel, "posterior", return_value=mp):
212-
mm = MockModel(None)
213-
cmms = MockModel(None)
214-
with mock.patch.object(
215-
ScalarizedPosteriorTransform, "forward", return_value=mp
216-
):
217-
post_tf = ScalarizedPosteriorTransform(torch.rand(2, **tkwargs))
218-
MPS = ConstrainedMaxPosteriorSampling(
219-
mm, cmms, posterior_transform=post_tf
220-
)
221-
s = MPS(X, num_samples=num_samples)
222-
self.assertTrue(s.shape[-2] == num_samples)
216+
# Test selection (_convert_samples_to_scores is tested separately)
217+
m_model = SingleTaskGP(
218+
X, torch.randn(X.shape[0:-1], **tkwargs).unsqueeze(-1)
219+
)
220+
cmms = cmms2
221+
with torch.random.fork_rng():
222+
torch.manual_seed(123)
223+
Y = m_model.posterior(X=X, observation_noise=observation_noise).rsample(
224+
sample_shape=torch.Size([num_samples])
225+
)
226+
C = cmms.posterior(X=X, observation_noise=observation_noise).rsample(
227+
sample_shape=torch.Size([num_samples])
228+
)
229+
scores = CPS._convert_samples_to_scores(Y_samples=Y, C_samples=C)
230+
X_true = CPS.maximize_samples(
231+
X=X, samples=scores, num_samples=num_samples
232+
)
223233

224-
# without replacement
225-
psamples[..., 1, 0] = 1e-6
226-
with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
227-
mp = MockPosterior(None)
228-
with mock.patch.object(MockModel, "posterior", return_value=mp):
229-
mm = MockModel(None)
230-
cmms = MockModel(None)
231-
MPS = ConstrainedMaxPosteriorSampling(mm, cmms, replacement=False)
232-
if len(batch_shape) > 1:
233-
with self.assertRaises(NotImplementedError):
234-
MPS(X, num_samples=num_samples)
235-
else:
236-
s = MPS(X, num_samples=num_samples)
237-
self.assertTrue(s.shape[-2] == num_samples)
234+
torch.manual_seed(123)
235+
CPS = ConstrainedMaxPosteriorSampling(m_model, cmms)
236+
X_cand = CPS(
237+
X=X,
238+
num_samples=num_samples,
239+
observation_noise=observation_noise,
240+
)
241+
self.assertAllClose(X_true, X_cand)
242+
243+
# Test `_convert_samples_to_scores`
244+
N, num_constraints, batch_shape = 10, 3, torch.Size([2])
245+
X = torch.randn(*batch_shape, N, d, **tkwargs)
246+
Y_samples = torch.rand(num_samples, *batch_shape, N, 1, **tkwargs)
247+
C_samples = -torch.rand(
248+
num_samples, *batch_shape, N, num_constraints, **tkwargs
249+
)
250+
251+
Y_samples[0, 0, 3] = 1.234
252+
C_samples[0, 1, 1:, 1] = 0.123 + torch.arange(N - 1, **tkwargs)
253+
C_samples[1, 0, :, :] = 1 + (torch.arange(N).unsqueeze(-1) - N // 2) ** 2
254+
Y_samples[1, 1, 7] = 10
255+
scores = ConstrainedMaxPosteriorSampling(
256+
m_model, cmms
257+
)._convert_samples_to_scores(Y_samples=Y_samples, C_samples=C_samples)
258+
self.assertEqual(scores[0, 0].argmax().item(), 3)
259+
self.assertEqual(scores[0, 1].argmax().item(), 0)
260+
self.assertEqual(scores[1, 0].argmax().item(), N // 2)
261+
self.assertEqual(scores[1, 1].argmax().item(), 7)
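
For reference, an end-to-end usage sketch in the spirit of what the updated test exercises. All data below is random and model fitting is omitted, so treat it as illustrative rather than as the tutorial's actual code:

import torch
from botorch.generation.sampling import ConstrainedMaxPosteriorSampling
from botorch.models import ModelListGP, SingleTaskGP

tkwargs = {"dtype": torch.double}
X = torch.rand(20, 3, **tkwargs)    # 20 observed points in 3D
Y = torch.randn(20, 1, **tkwargs)   # objective observations
C1 = torch.randn(20, 1, **tkwargs)  # constraint observations, feasible iff c(x) <= 0
C2 = torch.randn(20, 1, **tkwargs)

model = SingleTaskGP(X, Y)
constraint_model = ModelListGP(SingleTaskGP(X, C1), SingleTaskGP(X, C2))

CMPS = ConstrainedMaxPosteriorSampling(model, constraint_model=constraint_model)
X_cand = torch.rand(100, 3, **tkwargs)  # candidate set to select from
X_next = CMPS(X_cand, num_samples=5)    # -> 5 x 3 tensor of selected candidates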
