Skip to content

Commit e26ef31

Browse files
authored
Allow setting global mode and add test (aws#88)
1 parent 7b31c99 commit e26ef31

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

smdebug/core/hook.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from smdebug.core.hook_utils import get_tensorboard_dir, verify_and_get_out_dir
2828
from smdebug.core.json_config import create_hook_from_json_config
2929
from smdebug.core.logger import get_logger
30-
from smdebug.core.modes import ALLOWED_MODES, ModeKeys
30+
from smdebug.core.modes import ALLOWED_MODE_NAMES, ALLOWED_MODES, ModeKeys
3131
from smdebug.core.reduction_config import ReductionConfig
3232
from smdebug.core.reductions import get_reduction_tensor_name
3333
from smdebug.core.sagemaker_utils import is_sagemaker_job
@@ -461,7 +461,7 @@ def set_mode(self, mode):
461461
self.mode = mode
462462
else:
463463
raise ValueError(
464-
"Invalid mode {}. Valid modes are {}.".format(mode, ",".join(ALLOWED_MODES))
464+
"Invalid mode {}. Valid modes are {}.".format(mode, ",".join(ALLOWED_MODE_NAMES))
465465
)
466466

467467
if mode not in self.mode_steps:

smdebug/core/modes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class ModeKeys(Enum):
1010
GLOBAL = 4
1111

1212

13-
ALLOWED_MODES = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]
13+
ALLOWED_MODES = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT, ModeKeys.GLOBAL]
1414
ALLOWED_MODE_NAMES = [x.name for x in ALLOWED_MODES]
1515
MODE_STEP_PLUGIN_NAME = "mode_step"
1616
MODE_PLUGIN_NAME = "mode"

tests/xgboost/test_hook.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44

55
# Third Party
66
import numpy as np
7+
import pytest
78
import xgboost
89

910
# First Party
10-
from smdebug import SaveConfig
11+
from smdebug import SaveConfig, modes
1112
from smdebug.core.access_layer.utils import has_training_ended
1213
from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR
1314
from smdebug.trials import create_trial
@@ -220,3 +221,11 @@ def test_hook_tensorboard_dir_created(tmpdir):
220221
hook = Hook(out_dir=out_dir, export_tensorboard=True)
221222
run_xgboost_model(hook=hook)
222223
assert "tensorboard" in os.listdir(out_dir)
224+
225+
226+
def test_setting_mode(tmpdir):
227+
out_dir = os.path.join(tmpdir, str(uuid.uuid4()))
228+
hook = Hook(out_dir=out_dir, export_tensorboard=True)
229+
hook.set_mode(modes.GLOBAL)
230+
with pytest.raises(ValueError):
231+
hook.set_mode("a")

0 commit comments

Comments
 (0)