Skip to content

Commit 686ae25

Browse files
authored
Add model saving warning at end of training (#171)
* Add model saving warning at end of training * Add warning when no model artifact is found * Add warning if model is not saved in the SavedModel bundle format * Combine logging messages * Enforce psutil version * Remove pinned version of sagemaker-containers and install this package last
1 parent b3cb548 commit 686ae25

File tree

8 files changed

+71
-6
lines changed

8 files changed

+71
-6
lines changed

docker/1.12.0/Dockerfile.cpu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,13 @@ COPY $framework_support_installable .
7676

7777
RUN pip install --no-cache-dir -U \
7878
keras==2.2.4 \
79-
sagemaker-containers==2.4.2 \
80-
$framework_support_installable \
8179
"sagemaker-tensorflow>=1.12,<1.13" && \
8280
# Let's install TensorFlow separately in the end to avoid
8381
# the library version to be overwritten
8482
pip install --force-reinstall --no-cache-dir -U \
8583
tensorflow-1.12.0-py2.py3-none-any.whl \
8684
horovod && \
85+
pip install --no-cache-dir -U $framework_support_installable && \
8786
rm -f tensorflow-1.12.0-py2.py3-none-any.whl && \
8887
rm -f $framework_support_installable && \
8988
pip uninstall -y --no-cache-dir \

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def read(fname):
4949
'Programming Language :: Python :: 3.6',
5050
],
5151

52-
install_requires=['sagemaker-containers>=2.3.4', 'numpy', 'scipy', 'sklearn',
52+
install_requires=['sagemaker-containers>=2.4.4', 'numpy', 'scipy', 'sklearn',
5353
'pandas', 'Pillow', 'h5py'],
5454
extras_require={
5555
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', 'mock',

src/sagemaker_tensorflow_container/training.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
logger = logging.getLogger(__name__)
2929

3030
SAGEMAKER_PARAMETER_SERVER_ENABLED = 'sagemaker_parameter_server_enabled'
31+
MODEL_DIR = '/opt/ml/model'
3132

3233

3334
def _is_host_master(hosts, current_host):
@@ -159,6 +160,33 @@ def train(env):
159160
runner=runner_type)
160161

161162

163+
def _log_model_missing_warning(model_dir):
164+
pb_file_exists = False
165+
file_exists = False
166+
for dirpath, dirnames, filenames in os.walk(model_dir):
167+
if filenames:
168+
file_exists = True
169+
for f in filenames:
170+
if 'saved_model.pb' in f or 'saved_model.pbtxt' in f:
171+
pb_file_exists = True
172+
path, direct_parent_dir = os.path.split(dirpath)
173+
if not str.isdigit(direct_parent_dir):
174+
logger.warn('Your model will NOT be servable with SageMaker TensorFlow Serving containers.'
175+
'The SavedModel bundle is under directory \"{}\", not a numeric name.'
176+
.format(direct_parent_dir))
177+
178+
if not file_exists:
179+
logger.warn('No model artifact is saved under path {}.'
180+
' Your training job will not save any model files to S3.\n'
181+
'For details of how to construct your training script see:\n'
182+
'https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/tensorflow#adapting-your-local-tensorflow-script' # noqa
183+
.format(model_dir))
184+
elif not pb_file_exists:
185+
logger.warn('Your model will NOT be servable with SageMaker TensorFlow Serving container.'
186+
'The model artifact was not saved in the TensorFlow SavedModel directory structure:\n'
187+
'https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory')
188+
189+
162190
def main():
163191
"""Training entry point
164192
"""
@@ -167,3 +195,4 @@ def main():
167195
s3_utils.configure(env.hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
168196
logger.setLevel(env.log_level)
169197
train(env)
198+
_log_model_missing_warning(MODEL_DIR)

test/resources/test_dir_correct_model/12345/saved_model.pb

Whitespace-only changes.

test/resources/test_dir_wrong_model/fake_model.h5

Whitespace-only changes.

test/resources/test_dir_wrong_parent_dir/not-digit/saved_model.pb

Whitespace-only changes.

test/unit/test_training.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
PS_TASK_2 = {'index': 1, 'type': 'ps'}
4242
MODEL_DIR = 's3://bucket/prefix'
4343
REGION = 'us-west-2'
44+
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'resources')
4445

4546

4647
@pytest.fixture
@@ -200,18 +201,54 @@ def test_build_tf_config_error():
200201
assert 'Cannot have a ps task if there are no parameter servers in the cluster' in str(error)
201202

202203

204+
@patch('sagemaker_tensorflow_container.training.logger')
205+
def test_log_model_missing_warning_no_model(logger):
206+
path = os.path.join(RESOURCE_PATH, 'test_dir_empty')
207+
if not os.path.exists(path):
208+
os.mkdir(path)
209+
training._log_model_missing_warning(path)
210+
logger.warn.assert_called_with('No model artifact is saved under path {}.'
211+
' Your training job will not save any model files to S3.\n'
212+
'For details of how to construct your training script see:\n'
213+
'https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/tensorflow#adapting-your-local-tensorflow-script' # noqa
214+
.format(path))
215+
216+
217+
@patch('sagemaker_tensorflow_container.training.logger')
218+
def test_log_model_missing_warning_wrong_format(logger):
219+
training._log_model_missing_warning(os.path.join(RESOURCE_PATH, 'test_dir_wrong_model'))
220+
logger.warn.assert_called_with('Your model will NOT be servable with SageMaker TensorFlow Serving container.'
221+
'The model artifact was not saved in the TensorFlow '
222+
'SavedModel directory structure:\n'
223+
'https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory')
224+
225+
226+
@patch('sagemaker_tensorflow_container.training.logger')
227+
def test_log_model_missing_warning_wrong_parent_dir(logger):
228+
training._log_model_missing_warning(os.path.join(RESOURCE_PATH, 'test_dir_wrong_parent_dir'))
229+
logger.warn.assert_called_with('Your model will NOT be servable with SageMaker TensorFlow Serving containers.'
230+
'The SavedModel bundle is under directory \"{}\", not a numeric name.'
231+
.format('not-digit'))
232+
233+
234+
@patch('sagemaker_tensorflow_container.training.logger')
235+
def test_log_model_missing_warning_correct(logger):
236+
training._log_model_missing_warning(os.path.join(RESOURCE_PATH, 'test_dir_correct_model'))
237+
logger.warn.assert_not_called()
238+
239+
240+
@patch('sagemaker_tensorflow_container.training.logger')
203241
@patch('sagemaker_tensorflow_container.training.train')
204242
@patch('logging.Logger.setLevel')
205243
@patch('sagemaker_containers.beta.framework.training_env')
206244
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={})
207245
@patch('sagemaker_tensorflow_container.s3_utils.configure')
208246
def test_main(configure_s3_env, read_hyperparameters, training_env,
209-
set_level, train, single_machine_training_env):
247+
set_level, train, logger, single_machine_training_env):
210248
training_env.return_value = single_machine_training_env
211249
os.environ['SAGEMAKER_REGION'] = REGION
212250
training.main()
213251
read_hyperparameters.assert_called_once_with()
214252
training_env.assert_called_once_with(hyperparameters={})
215-
set_level.assert_called_once_with(LOG_LEVEL)
216253
train.assert_called_once_with(single_machine_training_env)
217254
configure_s3_env.assert_called_once()

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ python =
1212
3.6: py36, flake8
1313

1414
[flake8]
15-
max-line-length = 100
15+
max-line-length = 120
1616
exclude =
1717
build/
1818
.git

0 commit comments

Comments
 (0)