
Commit 04d8f05

esantorella authored and facebook-github-bot committed
Remove some usages of _filter_kwargs and code that triggers warnings from _filter_kwargs (#1895)
Summary:

## Motivation

Unused keyword arguments are a perennial bugbear. All else equal, it's better for a function to error upon receiving an unused keyword argument than to silently ignore it. `_filter_kwargs` gives a warning rather than an error when unused keyword arguments are passed. The warnings were cluttering unit test output and annoying me, so I removed usage of `_filter_kwargs` where it was straightforward to do so. I left a couple of calls that would have been tricky to change. I also updated a tutorial to stop using the deprecated `fit_gpytorch_torch` function, which tends to generate a lot of `_filter_kwargs` warnings.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes

Pull Request resolved: #1895

Test Plan: Units

Reviewed By: SebastianAment

Differential Revision: D46831065

Pulled By: esantorella

fbshipit-source-id: db8843cbe7e7649afe57472a3c6fae60c6e33fdb
1 parent 586a53a commit 04d8f05
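For context, here is a minimal, self-contained sketch of the pattern the motivation describes. This is an illustration, not BoTorch's actual `_filter_kwargs` implementation; the `step` function and the `_filter_kwargs_sketch` helper are hypothetical. The filtered call merely warns about an unused option, while the direct call fails fast.

```python
import warnings
from inspect import signature


def _filter_kwargs_sketch(function, **kwargs):
    # Keep only the keyword arguments that `function` accepts; warn
    # about (and drop) the rest. Hypothetical stand-in for
    # botorch.optim.utils._filter_kwargs.
    allowed = set(signature(function).parameters)
    unused = [k for k in kwargs if k not in allowed]
    if unused:
        warnings.warn(
            f"Keyword arguments {unused} will be ignored because they are "
            f"not allowed parameters for function {function.__name__}."
        )
    return {k: v for k, v in kwargs.items() if k in allowed}


def step(maxiter: int = 10, rel_tol: float = 1e-5) -> int:
    return maxiter


# Filtered call: "disp" is dropped after a warning, and the call succeeds.
step(**_filter_kwargs_sketch(step, maxiter=3, disp=False))

# Direct call: the unused option surfaces immediately.
step(maxiter=3, disp=False)  # TypeError: unexpected keyword argument 'disp'
```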

7 files changed: +737, -699 lines changed

botorch/generation/gen.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -29,7 +29,7 @@
     NLC_TOL,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
-from botorch.optim.utils import _filter_kwargs, columnwise_clamp, fix_features
+from botorch.optim.utils import columnwise_clamp, fix_features
 from botorch.optim.utils.timeout import minimize_with_timeout
 from scipy.optimize import OptimizeResult
 from torch import Tensor
@@ -367,9 +367,7 @@ def gen_candidates_torch(
 
     i = 0
     stop = False
-    stopping_criterion = ExpMAStoppingCriterion(
-        **_filter_kwargs(ExpMAStoppingCriterion, **options)
-    )
+    stopping_criterion = ExpMAStoppingCriterion(**options)
     while not stop:
         i += 1
         with torch.no_grad():
```
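With this change, the `options` dict passed to `gen_candidates_torch` flows unfiltered into `ExpMAStoppingCriterion`, so an unrecognized option now fails fast instead of being dropped with a warning. A rough sketch of the difference, assuming `maxiter` is an accepted parameter of `ExpMAStoppingCriterion` (as its use elsewhere in BoTorch suggests) and `disp` is not:

```python
from botorch.optim.stopping import ExpMAStoppingCriterion

# A supported option works as before.
criterion = ExpMAStoppingCriterion(maxiter=50)

# An unsupported option such as "disp" used to be filtered out with a
# warning; after this change it reaches the constructor and fails fast.
ExpMAStoppingCriterion(disp=False)  # raises TypeError
```

This is also why the test changes below drop `options={"disp": False}` from `gen_candidates_torch` calls.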

botorch/optim/optimize.py

Lines changed: 11 additions & 15 deletions
```diff
@@ -282,11 +282,6 @@ def _optimize_acqf_batch(opt_inputs: OptimizeAcqfInputs) -> Tuple[Tensor, Tensor
         if not opt_inputs.nonlinear_inequality_constraints
         else 1,
     )
-    has_parameter_constraints = (
-        opt_inputs.inequality_constraints is not None
-        or opt_inputs.equality_constraints is not None
-        or opt_inputs.nonlinear_inequality_constraints is not None
-    )
 
     def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
         batch_candidates_list: List[Tensor] = []
@@ -308,16 +303,17 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
             "timeout_sec": timeout_sec,
         }
 
-        if has_parameter_constraints:
-            gen_kwargs.update(
-                {
-                    "inequality_constraints": opt_inputs.inequality_constraints,
-                    "equality_constraints": opt_inputs.equality_constraints,
-                    "nonlinear_inequality_constraints": (
-                        opt_inputs.nonlinear_inequality_constraints
-                    ),
-                }
-            )
+        # only add parameter constraints to gen_kwargs if they are specified
+        # to avoid unnecessary warnings in _filter_kwargs
+        if opt_inputs.inequality_constraints is not None:
+            gen_kwargs["inequality_constraints"] = opt_inputs.inequality_constraints
+        if opt_inputs.equality_constraints is not None:
+            gen_kwargs["equality_constraints"] = opt_inputs.equality_constraints
+        if opt_inputs.nonlinear_inequality_constraints is not None:
+            gen_kwargs[
+                "nonlinear_inequality_constraints"
+            ] = opt_inputs.nonlinear_inequality_constraints
+
         filtered_gen_kwargs = _filter_kwargs(opt_inputs.gen_candidates, **gen_kwargs)
 
         for i, batched_ics_ in enumerate(batched_ics):
```
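The pattern above, adding a key to a kwargs dict only when the caller actually supplied a value, is what keeps `_filter_kwargs` quiet in the common unconstrained case. A standalone sketch of the same idea (the `build_gen_kwargs` helper is hypothetical, not part of BoTorch):

```python
from typing import Any, Dict, Optional


def build_gen_kwargs(
    timeout_sec: Optional[float] = None,
    inequality_constraints: Optional[list] = None,
    equality_constraints: Optional[list] = None,
) -> Dict[str, Any]:
    # Include a constraint key only when the caller actually set it, so a
    # downstream kwarg filter has nothing to warn about when the candidate
    # generator does not support constraints.
    gen_kwargs: Dict[str, Any] = {"timeout_sec": timeout_sec}
    if inequality_constraints is not None:
        gen_kwargs["inequality_constraints"] = inequality_constraints
    if equality_constraints is not None:
        gen_kwargs["equality_constraints"] = equality_constraints
    return gen_kwargs


# Unconstrained call: no constraint keys, hence nothing to filter or warn on.
assert "equality_constraints" not in build_gen_kwargs(timeout_sec=1.0)
```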

botorch/optim/stopping.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -6,7 +6,6 @@
 
 from __future__ import annotations
 
-import typing  # noqa F401
 from abc import ABC, abstractmethod
 
 import torch
```

test/generation/test_gen.py

Lines changed: 18 additions & 21 deletions
```diff
@@ -105,9 +105,7 @@ def test_gen_candidates(
         self.assertTrue(-EPS <= candidates <= 1 + EPS)
 
     def test_gen_candidates_torch(self):
-        self.test_gen_candidates(
-            gen_candidates=gen_candidates_torch, options={"disp": False}
-        )
+        self.test_gen_candidates(gen_candidates=gen_candidates_torch)
 
     def test_gen_candidates_with_none_fixed_features(
         self,
@@ -144,7 +142,7 @@ def test_gen_candidates_with_none_fixed_features(
 
     def test_gen_candidates_torch_with_none_fixed_features(self):
         self.test_gen_candidates_with_none_fixed_features(
-            gen_candidates=gen_candidates_torch, options={"disp": False}
+            gen_candidates=gen_candidates_torch
         )
 
     def test_gen_candidates_with_fixed_features(
@@ -184,21 +182,20 @@ def test_gen_candidates_with_fixed_features(
     def test_gen_candidates_with_fixed_features_and_timeout(self):
         with self.assertLogs("botorch", level="INFO") as logs:
             self.test_gen_candidates_with_fixed_features(
-                options={"disp": False},
                 timeout_sec=1e-4,
+                options={"disp": False},
             )
         self.assertTrue(any("Optimization timed out" in o for o in logs.output))
 
     def test_gen_candidates_torch_with_fixed_features(self):
         self.test_gen_candidates_with_fixed_features(
-            gen_candidates=gen_candidates_torch, options={"disp": False}
+            gen_candidates=gen_candidates_torch
         )
 
     def test_gen_candidates_torch_with_fixed_features_and_timeout(self):
         with self.assertLogs("botorch", level="INFO") as logs:
             self.test_gen_candidates_with_fixed_features(
                 gen_candidates=gen_candidates_torch,
-                options={"disp": False},
                 timeout_sec=1e-4,
             )
         self.assertTrue(any("Optimization timed out" in o for o in logs.output))
@@ -335,23 +332,23 @@ def test_gen_candidates_scipy_nan_handling(self):
             acquisition_function=mock.Mock(),
         )
 
-    def test_gen_candidates_without_grad(self):
+    def test_gen_candidates_without_grad(self) -> None:
+        """Test with `with_grad=False` (not supported for gen_candidates_torch)."""
 
-        for gen_candidates in (gen_candidates_scipy, gen_candidates_torch):
-            self.test_gen_candidates(
-                gen_candidates=gen_candidates,
-                options={"disp": False, "with_grad": False},
-            )
+        self.test_gen_candidates(
+            gen_candidates=gen_candidates_scipy,
+            options={"disp": False, "with_grad": False},
+        )
 
-            self.test_gen_candidates_with_fixed_features(
-                gen_candidates=gen_candidates,
-                options={"disp": False, "with_grad": False},
-            )
+        self.test_gen_candidates_with_fixed_features(
+            gen_candidates=gen_candidates_scipy,
+            options={"disp": False, "with_grad": False},
+        )
 
-            self.test_gen_candidates_with_none_fixed_features(
-                gen_candidates=gen_candidates,
-                options={"disp": False, "with_grad": False},
-            )
+        self.test_gen_candidates_with_none_fixed_features(
+            gen_candidates=gen_candidates_scipy,
+            options={"disp": False, "with_grad": False},
+        )
 
 
 class TestRandomRestartOptimization(TestBaseCandidateGeneration):
```

test/optim/test_fit.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -393,7 +393,7 @@ def test_fit_gpytorch_torch(self):
             self._test_fit_gpytorch_torch(mll.to(dtype=dtype))
 
     def _test_fit_gpytorch_torch(self, mll):
-        options = {"disp": False, "maxiter": 3}
+        options = {"maxiter": 3}
         ckpt = {
             k: TensorCheckpoint(v.detach().clone(), v.device, v.dtype)
             for k, v in mll.state_dict().items()
```
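Relatedly, the commit message mentions migrating a tutorial off the deprecated `fit_gpytorch_torch`. For reference, a minimal sketch of the replacement path via `fit_gpytorch_mll`, with toy data and model chosen here purely for illustration (not taken from this commit):

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy training data; double precision is the BoTorch convention.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)  # replaces the deprecated fit_gpytorch_torch
```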

test/optim/test_optimize.py

Lines changed: 16 additions & 6 deletions
```diff
@@ -292,9 +292,19 @@ def test_optimize_acqf_sequential(
                 4 * torch.ones(3, device=self.device, dtype=dtype),
             ]
         )
-        inequality_constraints = [
-            (torch.tensor([2]), torch.tensor([4]), torch.tensor(5))
-        ]
+        if mock_gen_candidates is mock_gen_candidates_scipy:
+            # x[2] * 4 >= 5
+            inequality_constraints = [
+                (torch.tensor([2]), torch.tensor([4]), torch.tensor(5))
+            ]
+            equality_constraints = [
+                (torch.tensor([0, 1]), torch.ones(2), torch.tensor(4.0))
+            ]
+        # gen_candidates_torch does not support constraints
+        else:
+            inequality_constraints = None
+            equality_constraints = None
+
         mock_gen_candidates.reset_mock()
         candidates, acq_value = optimize_acqf(
             acq_function=mock_acq_function,
@@ -304,6 +314,7 @@ def test_optimize_acqf_sequential(
             raw_samples=raw_samples,
             options=options,
             inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
             post_processing_func=rounding_func if use_rounding else None,
             sequential=True,
             timeout_sec=timeout_sec,
@@ -1015,9 +1026,8 @@ def nlc(x):
         if mock_gen_candidates == mock_gen_candidates_torch:
             self.assertEqual(len(ws), 3)
             message = (
-                "Keyword arguments ['nonlinear_inequality_constraints',"
-                " 'equality_constraints', 'inequality_constraints'] will"
-                " be ignored because they are not allowed parameters for"
+                "Keyword arguments ['nonlinear_inequality_constraints']"
+                " will be ignored because they are not allowed parameters for"
                 " function gen_candidates. Allowed parameters are "
                 " ['initial_conditions', 'acquisition_function', "
                 "'lower_bounds', 'upper_bounds', 'optimizer', 'options',"
```

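For readers unfamiliar with the constraint tuples added in the test above: BoTorch parameter constraints are `(indices, coefficients, rhs)` triples, read as `sum(X[indices] * coefficients) >= rhs` for inequality constraints and `== rhs` for equality constraints. A small sketch decoding the two constraints from the test (the `satisfies` helper is hypothetical):

```python
import torch

# Inequality constraint from the test: x[2] * 4 >= 5
inequality = (torch.tensor([2]), torch.tensor([4.0]), torch.tensor(5.0))

# Equality constraint from the test: x[0] + x[1] == 4
equality = (torch.tensor([0, 1]), torch.ones(2), torch.tensor(4.0))


def satisfies(x: torch.Tensor, constraint, is_equality: bool = False) -> bool:
    # Evaluate sum(x[indices] * coefficients) against rhs.
    indices, coefficients, rhs = constraint
    lhs = (x[indices] * coefficients).sum()
    return bool(torch.isclose(lhs, rhs)) if is_equality else bool(lhs >= rhs)


x = torch.tensor([1.0, 3.0, 2.0])
print(satisfies(x, inequality))                    # True: 2 * 4 = 8 >= 5
print(satisfies(x, equality, is_equality=True))    # True: 1 + 3 == 4
```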