Skip to content

Commit 4cc7ab9

Browse files
authored
Raise Error For Invalid Collection Config (aws#162)
1 parent d2f9c2b commit 4cc7ab9

File tree

12 files changed

+120
-18
lines changed

12 files changed

+120
-18
lines changed

smdebug/core/collection.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ class CollectionKeys:
3434

3535
OPTIMIZER_VARIABLES = "optimizer_variables"
3636
TENSORFLOW_SUMMARIES = "tensorflow_summaries"
37+
METRICS = "metrics"
3738

3839
# XGBOOST
3940
HYPERPARAMETERS = "hyperparameters"
40-
METRICS = "metrics"
4141
PREDICTIONS = "predictions"
4242
LABELS = "labels"
4343
FEATURE_IMPORTANCE = "feature_importance"
@@ -65,6 +65,50 @@ class CollectionKeys:
6565

6666
NON_HISTOGRAM_COLLECTIONS = SCALAR_COLLECTIONS.union(SUMMARIES_COLLECTIONS)
6767

68+
# Built-in collection names, one set per framework. Each framework's hook
# returns its set from ``_get_default_collections`` and its collection
# manager creates one collection per name.

# TensorFlow
DEFAULT_TF_COLLECTIONS = {
    CollectionKeys.ALL,
    CollectionKeys.DEFAULT,
    CollectionKeys.WEIGHTS,
    CollectionKeys.BIASES,
    CollectionKeys.GRADIENTS,
    CollectionKeys.LOSSES,
    CollectionKeys.METRICS,
    CollectionKeys.INPUTS,
    CollectionKeys.OUTPUTS,
    CollectionKeys.SM_METRICS,
    CollectionKeys.OPTIMIZER_VARIABLES,
}

# PyTorch
DEFAULT_PYTORCH_COLLECTIONS = {
    CollectionKeys.ALL,
    CollectionKeys.DEFAULT,
    CollectionKeys.WEIGHTS,
    CollectionKeys.BIASES,
    CollectionKeys.GRADIENTS,
    CollectionKeys.LOSSES,
}

# MXNet (currently the same members as the PyTorch set, kept separate so the
# two frameworks can diverge independently)
DEFAULT_MXNET_COLLECTIONS = {
    CollectionKeys.ALL,
    CollectionKeys.DEFAULT,
    CollectionKeys.WEIGHTS,
    CollectionKeys.BIASES,
    CollectionKeys.GRADIENTS,
    CollectionKeys.LOSSES,
}

# XGBoost
DEFAULT_XGBOOST_COLLECTIONS = {
    CollectionKeys.ALL,
    CollectionKeys.DEFAULT,
    CollectionKeys.HYPERPARAMETERS,
    CollectionKeys.PREDICTIONS,
    CollectionKeys.LABELS,
    CollectionKeys.FEATURE_IMPORTANCE,
    CollectionKeys.AVERAGE_SHAP,
    CollectionKeys.FULL_SHAP,
    CollectionKeys.TREES,
}
111+
68112

69113
class Collection:
70114
"""

smdebug/core/hook.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from smdebug.core.state_store import StateStore
3737
from smdebug.core.utils import flatten, get_tb_worker, match_inc, size_and_shape
3838
from smdebug.core.writer import FileWriter
39+
from smdebug.exceptions import InvalidCollectionConfiguration
3940

4041
try:
4142
from smexperiments.metrics import SageMakerFileMetricsWriter
@@ -311,10 +312,17 @@ def _get_collections_with_tensor(self, tensor_name) -> Set["Collection"]:
311312
self.tensor_to_collections[tensor_name] = matched_colls
312313
return self.tensor_to_collections[tensor_name]
313314

315+
@abstractmethod
316+
def _get_default_collections(self):
317+
pass
318+
314319
def _prepare_collections(self):
315320
"""Populate collections_to_save and ensure every collection has
316321
a save_config and reduction_config."""
317322
for c_name, c in self.collection_manager.get_collections().items():
323+
if c_name not in self._get_default_collections():
324+
if bool(c.include_regex) is False and bool(c.tensor_names) is False:
325+
raise InvalidCollectionConfiguration(c_name)
318326
if c in self._collections_to_save:
319327
continue
320328
elif self._should_collection_be_saved(CollectionKeys.ALL):

smdebug/exceptions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,15 @@
22
from smdebug.core.modes import ModeKeys as modes
33

44

5+
class InvalidCollectionConfiguration(Exception):
    """Raised for a collection that has no tensor-selection criteria.

    Args:
        c_name: Name of the offending collection; kept on the instance as
            ``c_name`` and interpolated into the message.
    """

    def __init__(self, c_name):
        super().__init__(c_name)
        self.c_name = c_name

    def __str__(self):
        # Adjacent string literals instead of a backslash continuation: a
        # continuation inside the literal would embed the next source line's
        # indentation in the message.
        return (
            f"Collection {self.c_name} has not been configured. "
            "Please fill in tensor_names or include_regex"
        )
12+
13+
514
class StepNotYetAvailable(Exception):
615
def __init__(self, step, mode):
716
self.step = step

smdebug/mxnet/collection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# First Party
2+
from smdebug.core.collection import DEFAULT_MXNET_COLLECTIONS
23
from smdebug.core.collection import Collection as BaseCollection
34
from smdebug.core.collection import CollectionKeys
45
from smdebug.core.collection_manager import CollectionManager as BaseCollectionManager
@@ -21,6 +22,8 @@ def __init__(self, create_default=True):
2122
self._register_default_collections()
2223

2324
def _register_default_collections(self):
25+
for c in DEFAULT_MXNET_COLLECTIONS:
26+
self.create_collection(c)
2427
self.get(CollectionKeys.WEIGHTS).include("^(?!gradient).*weight")
2528
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
2629
self.get(CollectionKeys.GRADIENTS).include("^gradient")

smdebug/mxnet/hook.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import mxnet as mx
33

44
# First Party
5-
from smdebug.core.collection import CollectionKeys
5+
from smdebug.core.collection import DEFAULT_MXNET_COLLECTIONS, CollectionKeys
66
from smdebug.core.hook import CallbackHook
77
from smdebug.core.json_config import DEFAULT_WORKER_NAME
88
from smdebug.mxnet.collection import CollectionManager
@@ -113,6 +113,9 @@ def _export_model(self):
113113
f"due to the mxnet exception: {e}"
114114
)
115115

116+
def _get_default_collections(self):
    """Return the set of collection names that are built in for MXNet."""
    return DEFAULT_MXNET_COLLECTIONS
118+
116119
# This hook is invoked by trainer prior to running the forward pass.
117120
def forward_pre_hook(self, block, inputs):
118121
if self.writer is not None:

smdebug/pytorch/collection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# First Party
2+
from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS
23
from smdebug.core.collection import Collection as BaseCollection
34
from smdebug.core.collection import CollectionKeys
45
from smdebug.core.collection_manager import CollectionManager as BaseCollectionManager
@@ -36,6 +37,8 @@ def __init__(self, create_default=True):
3637
self._register_default_collections()
3738

3839
def _register_default_collections(self):
40+
for c in DEFAULT_PYTORCH_COLLECTIONS:
41+
self.create_collection(c)
3942
self.get(CollectionKeys.WEIGHTS).include("^(?!gradient).*weight")
4043
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
4144
self.get(CollectionKeys.GRADIENTS).include("^gradient")

smdebug/pytorch/hook.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch.distributed as dist
66

77
# First Party
8-
from smdebug.core.collection import CollectionKeys
8+
from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
99
from smdebug.core.hook import CallbackHook
1010
from smdebug.core.json_config import DEFAULT_WORKER_NAME
1111
from smdebug.pytorch.collection import CollectionManager
@@ -103,15 +103,18 @@ def _log_params(self, module):
103103
def _export_model(self):
104104
pass
105105

106+
def _get_default_collections(self):
    """Return the set of collection names that are built in for PyTorch."""
    return DEFAULT_PYTORCH_COLLECTIONS
108+
106109
def _prepare_collections(self):
    """Add include patterns for registered modules, then run the base preparation.

    For every module registered on a collection, the module's mapped name
    plus ``_input_`` / ``_output_`` is added as an include pattern,
    depending on which of the two was requested.
    """
    for collection in self.collection_manager.collections.values():
        for module, (want_inputs, want_outputs) in collection.modules.items():
            prefix = self.module_maps[module]
            if want_inputs:
                collection.include(prefix + "_input_")
            if want_outputs:
                collection.include(prefix + "_output_")
    # NOTE(review): the base call is deliberately last, presumably so that
    # the base-class validation sees the patterns added above; confirm.
    super()._prepare_collections()
115118

116119
# This hook is invoked by trainer prior to running the forward pass.
117120
def forward_pre_hook(self, module, inputs):

smdebug/tensorflow/base_hook.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tensorflow.python.distribute.distribute_lib import _DefaultDistributionStrategy
88

99
# First Party
10+
from smdebug.core.collection import DEFAULT_TF_COLLECTIONS
1011
from smdebug.core.config_constants import DEFAULT_WORKER_NAME
1112
from smdebug.core.hook import BaseHook
1213
from smdebug.core.modes import ModeKeys
@@ -179,6 +180,9 @@ def _get_worker_name(self) -> str:
179180
elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
180181
raise NotImplementedError
181182

183+
def _get_default_collections(self):
    """Return the set of collection names that are built in for TensorFlow."""
    return DEFAULT_TF_COLLECTIONS
185+
182186
def export_collections(self):
183187
assert self._prepared_tensors[self.mode]
184188

smdebug/tensorflow/collection.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from tensorflow.python.distribute import values
1010

1111
# First Party
12+
from smdebug.core.collection import DEFAULT_TF_COLLECTIONS
1213
from smdebug.core.collection import Collection as BaseCollection
1314
from smdebug.core.collection import CollectionKeys
1415
from smdebug.core.collection_manager import CollectionManager as BaseCollectionManager
@@ -136,18 +137,7 @@ class CollectionManager(BaseCollectionManager):
136137
def __init__(self, collections=None, create_default=True):
137138
super().__init__(collections=collections)
138139
if create_default:
139-
for n in [
140-
CollectionKeys.DEFAULT,
141-
CollectionKeys.WEIGHTS,
142-
CollectionKeys.BIASES,
143-
CollectionKeys.GRADIENTS,
144-
CollectionKeys.LOSSES,
145-
CollectionKeys.METRICS,
146-
CollectionKeys.INPUTS,
147-
CollectionKeys.OUTPUTS,
148-
CollectionKeys.ALL,
149-
CollectionKeys.SM_METRICS,
150-
]:
140+
for n in DEFAULT_TF_COLLECTIONS:
151141
self.create_collection(n)
152142
self.get(CollectionKeys.BIASES).include("bias")
153143

smdebug/xgboost/collection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# First Party
2-
from smdebug.core.collection import CollectionKeys
2+
from smdebug.core.collection import DEFAULT_XGBOOST_COLLECTIONS, CollectionKeys
33
from smdebug.core.collection_manager import CollectionManager as BaseCollectionManager
44

55

@@ -10,6 +10,8 @@ def __init__(self, create_default=True):
1010
self._register_default_collections()
1111

1212
def _register_default_collections(self):
13+
for c in DEFAULT_XGBOOST_COLLECTIONS:
14+
self.create_collection(c)
1315
self.get(CollectionKeys.HYPERPARAMETERS).include("^hyperparameters/.*$")
1416
self.get(CollectionKeys.METRICS).include("^[a-zA-z]+-[a-zA-z0-9]+$")
1517
self.get(CollectionKeys.PREDICTIONS).include("^predictions$")

smdebug/xgboost/hook.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from xgboost.core import CallbackEnv
1010

1111
# First Party
12-
from smdebug.core.collection import CollectionKeys
12+
from smdebug.core.collection import DEFAULT_XGBOOST_COLLECTIONS, CollectionKeys
1313
from smdebug.core.hook import CallbackHook
1414
from smdebug.core.json_config import create_hook_from_json_config
1515
from smdebug.core.save_config import SaveConfig
@@ -144,6 +144,12 @@ def create_from_json_file(cls, json_file_path=None):
144144
def hook_from_config(cls, json_config_path=None):
145145
return cls.create_from_json_file(json_file_path=json_config_path)
146146

147+
def _get_default_collections(self):
    """Return the set of collection names that are built in for XGBoost."""
    return DEFAULT_XGBOOST_COLLECTIONS
149+
150+
def _prepare_collections(self):
    """Prepare collections by delegating entirely to the parent class."""
    # NOTE(review): this override adds nothing over the inherited method and
    # could be removed.
    super()._prepare_collections()
152+
147153
def _is_last_step(self, env: CallbackEnv) -> bool:
148154
# env.iteration: current boosting round.
149155
# env.end_iteration: round # when training will end. this is always num_round + 1. # noqa: E501

tests/core/test_collections.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from smdebug.core.reduction_config import ReductionConfig
1111
from smdebug.core.save_config import SaveConfig, SaveConfigMode
1212
from smdebug.core.utils import get_path_to_collections
13+
from smdebug.exceptions import InvalidCollectionConfiguration
1314
from smdebug.mxnet.hook import Hook
1415

1516

@@ -87,6 +88,7 @@ def test_collection_defaults_to_hook_config():
8788
"""
8889
cm = CollectionManager()
8990
cm.create_collection("foo")
91+
cm.get("foo").include_regex = "*"
9092
cm.get("foo").save_config = {ModeKeys.EVAL: SaveConfigMode(save_interval=20)}
9193

9294
hook = Hook(
@@ -101,3 +103,28 @@ def test_collection_defaults_to_hook_config():
101103
hook._prepare_collections()
102104
assert cm.get("foo").save_config.mode_save_configs[ModeKeys.TRAIN].save_interval == 10
103105
assert cm.get("foo").reduction_config.save_raw_tensor is True
106+
107+
108+
def test_invalid_collection_config_exception():
    """An unconfigured custom collection is rejected; a configured one is accepted."""
    cm = CollectionManager()
    cm.create_collection("foo")

    hook = Hook(
        out_dir="/tmp/test_collections/" + str(datetime.datetime.now()),
        save_config={ModeKeys.TRAIN: SaveConfigMode(save_interval=10)},
        include_collections=["foo"],
        reduction_config=ReductionConfig(save_raw_tensor=True),
    )
    hook.collection_manager = cm

    # "foo" has neither include_regex nor tensor_names, so preparation must fail.
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration:
        pass
    else:
        # Explicit raise rather than `assert False`, which is stripped under -O.
        raise AssertionError("Invalid Collection Name did not raise error")

    # Configure the collection through include() with a valid pattern; a bare
    # "*" is not a valid regular expression.
    cm.get("foo").include(".*")
    try:
        hook._prepare_collections()
    except InvalidCollectionConfiguration as err:
        raise AssertionError("Valid Collection Name raised an error") from err

0 commit comments

Comments
 (0)