Skip to content

Commit ffd2211

Browse files
Edward J Kimjarednielsen
authored andcommitted
Override default save configs in xgboost (aws#341)
* Override default save configs in xgboost * Change kwargs in SaveConfig.from_dict() to default_values * default_values is an empty dict by default * Remove unncessary dict copy * Explicit is better than implicit * Parse default_values in SaveConfigMode
1 parent 2e29782 commit ffd2211

File tree

3 files changed

+48
-27
lines changed

3 files changed

+48
-27
lines changed

tornasole/core/json_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def get_json_config_as_dict(json_config_path) -> Dict:
8080
return params_dict
8181

8282

83-
def create_hook_from_json_config(hook_cls, collection_manager, json_config_path):
83+
def create_hook_from_json_config(
84+
hook_cls, collection_manager, json_config_path, default_values=None
85+
):
8486
"""Returns a TornasoleHook object corresponding to either TF, PT, or MXNet.
8587
8688
If json_config_path is None, an environment variable must be set.
@@ -100,7 +102,7 @@ def create_hook_from_json_config(hook_cls, collection_manager, json_config_path)
100102
out_dir = tornasole_params.get("out_dir", DEFAULT_SAGEMAKER_TORNASOLE_PATH)
101103
dry_run = tornasole_params.get("dry_run", False)
102104
reduction_config = tornasole_params.get(TORNASOLE_CONFIG_RDN_CFG_KEY)
103-
save_config = SaveConfig.from_dict(tornasole_params.get("save_config_modes"))
105+
save_config = SaveConfig.from_dict(tornasole_params.get("save_config_modes"), default_values)
104106
include_regex = tornasole_params.get(TORNASOLE_CONFIG_INCLUDE_REGEX_KEY)
105107
save_all = tornasole_params.get(TORNASOLE_CONFIG_SAVE_ALL_KEY, False)
106108
return hook_cls(

tornasole/core/save_config.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ def merge_default_save_config(self, default_save_config):
108108
self.set_save_config(mode=mode, save_config_mode=SaveConfigMode())
109109

110110
@classmethod
111-
def from_dict(cls, params: Dict[ModeKeys, Any]) -> "SaveConfig":
111+
def from_dict(
112+
cls, params: Dict[ModeKeys, Any], default_values: Dict[str, Any] = None
113+
) -> "SaveConfig":
112114
"""Parses a dict into a SaveConfig object.
113115
114116
Appropriate formats:
@@ -119,12 +121,17 @@ def from_dict(cls, params: Dict[ModeKeys, Any]) -> "SaveConfig":
119121
"""
120122
if params is None:
121123
return None
124+
if default_values is None:
125+
default_values = {}
122126
# Maybe convert strings to enums
123-
if all([isinstance(key, str) for key, value in params.items()]):
127+
if all(isinstance(key, str) for key, value in params.items()):
124128
params = {ModeKeys[key]: value for key, value in params.items()}
125129
# Maybe convert dicts to SaveConfigMode
126-
if all([value is None or isinstance(value, dict) for key, value in params.items()]):
127-
params = {key: SaveConfigMode.from_dict(value) for key, value in params.items()}
130+
if all(value is None or isinstance(value, dict) for key, value in params.items()):
131+
params = {
132+
key: SaveConfigMode.from_dict(value, default_values)
133+
for key, value in params.items()
134+
}
128135
return cls(mode_save_configs=params)
129136

130137
@classmethod
@@ -171,18 +178,18 @@ def __repr__(self):
171178

172179
class SaveConfigMode:
173180
"""
174-
Wrapping all the save configuration parameters into this object.
175-
This would make it easier to set different save configuration for
176-
different collections and for the base tensors saved.
181+
Wrapping all the save configuration parameters into this object.
182+
This would make it easier to set different save configuration for
183+
different collections and for the base tensors saved.
177184
178-
This class should not be serialized by itself, only inside of SaveConfig.
185+
This class should not be serialized by itself, only inside of SaveConfig.
179186
180-
Parameters:
181-
save_interval (int): Save every n steps.
182-
save_steps (list of int): Save at all the steps given in this list. Overrides save_interval.
183-
start_step (int): Save after n steps.
184-
end_step (int): Stop saving after n steps.
185-
"""
187+
Parameters:
188+
save_interval (int): Save every n steps.
189+
save_steps (list of int): Save at all the steps given in this list. Overrides save_interval.
190+
start_step (int): Save after n steps.
191+
end_step (int): Stop saving after n steps.
192+
"""
186193

187194
def __init__(
188195
self,
@@ -222,16 +229,20 @@ def to_json_dict(self):
222229
}
223230

224231
@classmethod
225-
def from_dict(cls, params: Dict[str, Any]):
232+
def from_dict(cls, params: Dict[str, Any], default_values: Dict[str, Any] = None):
226233
if params is None:
227234
return None
228235
elif not isinstance(params, dict):
229236
raise TypeError(f"params={params} is not a dict.")
237+
if default_values is None:
238+
default_values = {}
239+
elif not isinstance(default_values, dict):
240+
raise TypeError(f"default_values={default_values} is not a dict.")
230241
return cls(
231-
save_interval=params.get("save_interval"),
232-
start_step=params.get("start_step"),
233-
end_step=params.get("end_step"),
234-
save_steps=params.get("save_steps"),
242+
save_interval=params.get("save_interval", default_values.get("save_interval")),
243+
start_step=params.get("start_step", default_values.get("start_step")),
244+
end_step=params.get("end_step", default_values.get("end_step")),
245+
save_steps=params.get("save_steps", default_values.get("save_steps")),
235246
)
236247

237248
def __eq__(self, other):

tornasole/xgboost/hook.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from tornasole.core.hook import CallbackHook
1010
from tornasole.core.tfevent.util import make_numpy_array
1111
from tornasole.core.access_layer.utils import training_has_ended
12-
from tornasole.core.json_config import (
13-
create_hook_from_json_config,
14-
TORNASOLE_CONFIG_DEFAULT_WORKER_NAME,
15-
)
12+
from tornasole.core.json_config import create_hook_from_json_config
1613
from tornasole.xgboost.singleton_utils import set_hook
1714

1815

@@ -21,8 +18,10 @@
2118

2219

2320
DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.METRICS]
24-
2521
DEFAULT_SAVE_CONFIG_INTERVAL = 10
22+
DEFAULT_SAVE_CONFIG_START_STEP = 0
23+
DEFAULT_SAVE_CONFIG_END_STEP = None
24+
DEFAULT_SAVE_CONFIG_SAVE_STEPS = []
2625

2726

2827
class TornasoleHook(CallbackHook):
@@ -120,8 +119,17 @@ def hook_from_config(cls, json_config_path=None):
120119
Otherwise,
121120
return None.
122121
"""
122+
default_values = dict(
123+
save_interval=DEFAULT_SAVE_CONFIG_INTERVAL,
124+
start_step=DEFAULT_SAVE_CONFIG_START_STEP,
125+
end_step=DEFAULT_SAVE_CONFIG_END_STEP,
126+
save_steps=DEFAULT_SAVE_CONFIG_SAVE_STEPS,
127+
)
123128
return create_hook_from_json_config(
124-
cls, get_collection_manager(), json_config_path=json_config_path
129+
cls,
130+
get_collection_manager(),
131+
json_config_path=json_config_path,
132+
default_values=default_values,
125133
)
126134

127135
def _is_last_step(self, env: CallbackEnv) -> bool:

0 commit comments

Comments
 (0)