Fix SCBO tutorial #2034

Closed
139 changes: 78 additions & 61 deletions botorch/generation/sampling.py
@@ -228,21 +228,19 @@ def forward(self, X: Tensor, num_samples: int = 1) -> Tensor:


class ConstrainedMaxPosteriorSampling(MaxPosteriorSampling):
r"""Sample from a set of points according to
their max posterior value,
which also likely meet a set of constraints
c1(x) <= 0, c2(x) <= 0, ..., cm(x) <= 0
c1, c2, ..., cm are black-box constraint functions
Each constraint function is modeled by a separate
surrogate GP constraint model
We sample points for which the posterior value
for each constraint model <= 0,
as described in https://doi.org/10.48550/arxiv.2002.08526
r"""Constrained max posterior sampling.

Posterior sampling where we try to maximize an objective function while
simultaneously satisfying a set of constraints c1(x) <= 0, c2(x) <= 0,
..., cm(x) <= 0 where c1, c2, ..., cm are black-box constraint functions.
Each constraint function is modeled by a separate GP model. We follow the
procedure as described in https://doi.org/10.48550/arxiv.2002.08526.

Example:
>>> CMPS = ConstrainedMaxPosteriorSampling(model,
constraint_model=ModelListGP(cmodel1, cmodel2,
..., cmodelm) # models w/ feature dim d=3
>>> CMPS = ConstrainedMaxPosteriorSampling(
model,
constraint_model=ModelListGP(cmodel1, cmodel2),
)
>>> X = torch.rand(2, 100, 3)
>>> sampled_X = CMPS(X, num_samples=5)
"""
@@ -254,82 +252,101 @@ def __init__(
objective: Optional[MCAcquisitionObjective] = None,
posterior_transform: Optional[PosteriorTransform] = None,
replacement: bool = True,
minimize_constraints_only: bool = False,
Contributor:

Should we deprecate this arg so this doesn't cause hard failures in other code?

Contributor Author:

I think this should be ok since this functionality probably has no users (that would have likely resulted in Github issues since it wasn't working as intended).

) -> None:
r"""Constructor for the SamplingStrategy base class.

Args:
model: A fitted model.
objective: The MCAcquisitionObjective under
which the samples are evaluated.
objective: The MCAcquisitionObjective under which the samples are evaluated.
Defaults to `IdentityMCObjective()`.
posterior_transform: An optional PosteriorTransform.
posterior_transform: An optional PosteriorTransform for the objective
function (corresponding to `model`).
replacement: If True, sample with replacement.
constraint_model: either a ModelListGP where each submodel
is a GP model for one constraint function,
or a MultiTaskGP model where each task is one
constraint function
All constraints are of the form c(x) <= 0.
In the case when the constraint model predicts
that all candidates violate constraints,
we pick the candidates with minimum violation.
minimize_constraints_only: False by default; if True,
we will automatically return the candidates
with minimum posterior constraint values
(minimum predicted c(x) summed over all constraints),
regardless of predicted objective values.
constraint_model: either a ModelListGP where each submodel is a GP model for
one constraint function, or a MultiTaskGP model where each task is one
constraint function. All constraints are of the form c(x) <= 0. In the
case when the constraint model predicts that all candidates
violate constraints, we pick the candidates with minimum violation.
"""
if objective is not None:
raise NotImplementedError(
"`objective` is not supported for `ConstrainedMaxPosteriorSampling`."
)

super().__init__(
model=model,
objective=objective,
posterior_transform=posterior_transform,
replacement=replacement,
)
self.constraint_model = constraint_model
self.minimize_constraints_only = minimize_constraints_only

def _convert_samples_to_scores(self, Y_samples, C_samples) -> Tensor:
r"""Convert the objective and constraint samples into a score.

The logic is as follows:
- If a realization has at least one feasible candidate we use the objective
value as the score and set all infeasible candidates to -inf.
- If a realization doesn't have a feasible candidate we set the score to
the negative total violation of the constraints to incentivize choosing
the candidate with the smallest constraint violation.

Args:
Y_samples: A `num_samples x batch_shape x num_cand x 1`-dim Tensor of
samples from the objective function.
C_samples: A `num_samples x batch_shape x num_cand x num_constraints`-dim
Tensor of samples from the constraints.

Returns:
A `num_samples x batch_shape x num_cand x 1`-dim Tensor of scores.
"""
is_feasible = (C_samples <= 0).all(
dim=-1
) # num_samples x batch_shape x num_cand
has_feasible_candidate = is_feasible.any(dim=-1)

scores = Y_samples.clone()
scores[~is_feasible] = -float("inf")
if not has_feasible_candidate.all():
# Use negative total violation for samples where no candidate is feasible
total_violation = (
C_samples[~has_feasible_candidate]
.clamp(min=0)
.sum(dim=-1, keepdim=True)
)
scores[~has_feasible_candidate] = -total_violation
return scores
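To make the two branches concrete, a tiny hand-worked illustration of the scoring rule (a standalone reimplementation for exposition, with made-up tensors `Y`, `C`, and `C_bad`, not a call into the class):

import torch

# One posterior realization, 3 candidates, 2 constraints.
Y = torch.tensor([[1.0], [2.0], [3.0]])
C = torch.tensor([[-0.1, -0.2], [0.5, -0.3], [-0.4, 0.1]])

is_feasible = (C <= 0).all(dim=-1)  # tensor([True, False, False])
scores = Y.clone()
scores[~is_feasible] = -float("inf")
# A feasible candidate exists, so candidate 0 wins even though
# candidates 1 and 2 have larger objective samples.
assert scores.argmax(dim=0).item() == 0

# If no candidate were feasible, the negative total violation decides:
C_bad = torch.tensor([[0.5, 0.2], [0.1, 0.3], [2.0, 1.0]])
fallback = -C_bad.clamp(min=0).sum(dim=-1, keepdim=True)
assert fallback.argmax(dim=0).item() == 1  # smallest total violation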

def forward(
self, X: Tensor, num_samples: int = 1, observation_noise: bool = False
) -> Tensor:
r"""Sample from the model posterior.

Args:
X: A `batch_shape x N x d`-dim Tensor
from which to sample (in the `N`
dimension) according to the maximum
posterior value under the objective.
X: A `batch_shape x N x d`-dim Tensor from which to sample (in the `N`
dimension) according to the maximum posterior value under the objective.
num_samples: The number of samples to draw.
observation_noise: If True, sample with observation noise.

Returns:
A `batch_shape x num_samples x d`-dim
Tensor of samples from `X`, where
`X[..., i, :]` is the `i`-th sample.
A `batch_shape x num_samples x d`-dim Tensor of samples from `X`, where
`X[..., i, :]` is the `i`-th sample.
"""
posterior = self.model.posterior(X, observation_noise=observation_noise)
samples = posterior.rsample(sample_shape=torch.Size([num_samples]))
posterior = self.model.posterior(
X=X,
observation_noise=observation_noise,
# Note: `posterior_transform` is only used for the objective
posterior_transform=self.posterior_transform,
)
Y_samples = posterior.rsample(sample_shape=torch.Size([num_samples]))

c_posterior = self.constraint_model.posterior(
X, observation_noise=observation_noise
)
constraint_samples = c_posterior.rsample(sample_shape=torch.Size([num_samples]))
valid_samples = constraint_samples <= 0
if valid_samples.shape[-1] > 1:  # if more than one constraint
valid_samples = torch.all(valid_samples, dim=-1).unsqueeze(-1)
if (valid_samples.sum() == 0) or self.minimize_constraints_only:
# if none of the samples meet the constraints
# we pick the one that minimizes total violation
constraint_samples = constraint_samples.sum(dim=-1)
idcs = torch.argmin(constraint_samples, dim=-1)
if idcs.ndim > 1:
idcs = idcs.permute(*range(1, idcs.ndim), 0)
idcs = idcs.unsqueeze(-1).expand(*idcs.shape, X.size(-1))
Xe = X.expand(*constraint_samples.shape[1:], X.size(-1))
return torch.gather(Xe, -2, idcs)
# replace all violators with -infinity so they will never be chosen
replacement_infs = -torch.inf * torch.ones(samples.shape).to(X.device).to(
X.dtype
)
samples = torch.where(valid_samples, samples, replacement_infs)
return self.maximize_samples(X, samples, num_samples)
c_posterior = self.constraint_model.posterior(
X=X, observation_noise=observation_noise
)
C_samples = c_posterior.rsample(sample_shape=torch.Size([num_samples]))

# Convert the objective and constraint samples into a scalar-valued "score"
scores = self._convert_samples_to_scores(
Y_samples=Y_samples, C_samples=C_samples
)
return self.maximize_samples(X=X, samples=scores, num_samples=num_samples)
118 changes: 71 additions & 47 deletions test/generation/test_sampling.py
@@ -154,24 +154,33 @@ class TestConstrainedMaxPosteriorSampling(BotorchTestCase):
def test_init(self):
mm = MockModel(MockPosterior(mean=None))
cmms = MockModel(MockPosterior(mean=None))
MPS = ConstrainedMaxPosteriorSampling(mm, cmms)
self.assertEqual(MPS.model, mm)
self.assertTrue(MPS.replacement)
self.assertIsInstance(MPS.objective, IdentityMCObjective)
for replacement in (True, False):
MPS = ConstrainedMaxPosteriorSampling(mm, cmms, replacement=replacement)
self.assertEqual(MPS.model, mm)
self.assertEqual(MPS.replacement, replacement)
self.assertIsInstance(MPS.objective, IdentityMCObjective)

obj = LinearMCObjective(torch.rand(2))
MPS = ConstrainedMaxPosteriorSampling(
mm, cmms, objective=obj, replacement=False
)
self.assertEqual(MPS.objective, obj)
self.assertFalse(MPS.replacement)
with self.assertRaisesRegex(
NotImplementedError, "`objective` is not supported"
):
ConstrainedMaxPosteriorSampling(mm, cmms, objective=obj, replacement=False)

def test_constrained_max_posterior_sampling(self):
batch_shapes = (torch.Size(), torch.Size([3]), torch.Size([3, 2]))
dtypes = (torch.float, torch.double)
for batch_shape, dtype, N, num_samples, d in itertools.product(
batch_shapes, dtypes, (5, 6), (1, 2), (1, 2)
for (
batch_shape,
dtype,
N,
num_samples,
d,
observation_noise,
) in itertools.product(
batch_shapes, dtypes, (5, 6), (1, 2), (1, 2), (True, False)
):
tkwargs = {"device": self.device, "dtype": dtype}
expected_shape = torch.Size(list(batch_shape) + [num_samples] + [d])
# X is `batch_shape x N x d`.
X = torch.randn(*batch_shape, N, d, **tkwargs)
# the event shape is `num_samples x batch_shape x N x m`
@@ -196,42 +205,57 @@ def test_constrained_max_posterior_sampling(self):
cmms2 = ModelListGP(c_model1, c_model2)
cmms3 = ModelListGP(c_model1, c_model2, c_model3)
for cmms in [cmms1, cmms2, cmms3]:
MPS = ConstrainedMaxPosteriorSampling(mm, cmms)
s1 = MPS(X, num_samples=num_samples)
# run again with minimize_constraints_only
MPS = ConstrainedMaxPosteriorSampling(
mm, cmms, minimize_constraints_only=True
)
s2 = MPS(X, num_samples=num_samples)
assert s1.shape == s2.shape
CPS = ConstrainedMaxPosteriorSampling(mm, cmms)
s1 = CPS(
X=X,
num_samples=num_samples,
observation_noise=observation_noise,
)
self.assertEqual(s1.shape, expected_shape)

# ScalarizedPosteriorTransform w/ replacement
with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
mp = MockPosterior(None)
with mock.patch.object(MockModel, "posterior", return_value=mp):
mm = MockModel(None)
cmms = MockModel(None)
with mock.patch.object(
ScalarizedPosteriorTransform, "forward", return_value=mp
):
post_tf = ScalarizedPosteriorTransform(torch.rand(2, **tkwargs))
MPS = ConstrainedMaxPosteriorSampling(
mm, cmms, posterior_transform=post_tf
)
s = MPS(X, num_samples=num_samples)
self.assertTrue(s.shape[-2] == num_samples)
# Test selection (_convert_samples_to_scores is tested separately)
m_model = SingleTaskGP(
X, torch.randn(X.shape[0:-1], **tkwargs).unsqueeze(-1)
)
cmms = cmms2
with torch.random.fork_rng():
torch.manual_seed(123)
Y = m_model.posterior(X=X, observation_noise=observation_noise).rsample(
sample_shape=torch.Size([num_samples])
)
C = cmms.posterior(X=X, observation_noise=observation_noise).rsample(
sample_shape=torch.Size([num_samples])
)
scores = CPS._convert_samples_to_scores(Y_samples=Y, C_samples=C)
X_true = CPS.maximize_samples(
X=X, samples=scores, num_samples=num_samples
)

# without replacement
psamples[..., 1, 0] = 1e-6
with mock.patch.object(MockPosterior, "rsample", return_value=psamples):
mp = MockPosterior(None)
with mock.patch.object(MockModel, "posterior", return_value=mp):
mm = MockModel(None)
cmms = MockModel(None)
MPS = ConstrainedMaxPosteriorSampling(mm, cmms, replacement=False)
if len(batch_shape) > 1:
with self.assertRaises(NotImplementedError):
MPS(X, num_samples=num_samples)
else:
s = MPS(X, num_samples=num_samples)
self.assertTrue(s.shape[-2] == num_samples)
torch.manual_seed(123)
CPS = ConstrainedMaxPosteriorSampling(m_model, cmms)
X_cand = CPS(
X=X,
num_samples=num_samples,
observation_noise=observation_noise,
)
self.assertAllClose(X_true, X_cand)

# Test `_convert_samples_to_scores`
N, num_constraints, batch_shape = 10, 3, torch.Size([2])
X = torch.randn(*batch_shape, N, d, **tkwargs)
Y_samples = torch.rand(num_samples, *batch_shape, N, 1, **tkwargs)
C_samples = -torch.rand(
num_samples, *batch_shape, N, num_constraints, **tkwargs
)

Y_samples[0, 0, 3] = 1.234
C_samples[0, 1, 1:, 1] = 0.123 + torch.arange(N - 1, **tkwargs)
C_samples[1, 0, :, :] = 1 + (torch.arange(N).unsqueeze(-1) - N // 2) ** 2
Y_samples[1, 1, 7] = 10
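# Expected picks per realization: scores[0, 0] peaks at index 3 (all
# candidates feasible; Y_samples[0, 0, 3] is the largest objective value),
# scores[0, 1] at index 0 (the only feasible candidate), scores[1, 0] at
# N // 2 (nothing feasible, so the smallest total violation wins), and
# scores[1, 1] at index 7 (all feasible, largest objective).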
scores = ConstrainedMaxPosteriorSampling(
m_model, cmms
)._convert_samples_to_scores(Y_samples=Y_samples, C_samples=C_samples)
self.assertEqual(scores[0, 0].argmax().item(), 3)
self.assertEqual(scores[0, 1].argmax().item(), 0)
self.assertEqual(scores[1, 0].argmax().item(), N // 2)
self.assertEqual(scores[1, 1].argmax().item(), 7)