Skip to content

Commit 1b083f5

Browse files
author
Ignacio Quintero
committed
Merge branch 'master' into remove_cw_metrics_arg
2 parents 9967d52 + 6ff1e23 commit 1b083f5

File tree

4 files changed

+78
-23
lines changed

4 files changed

+78
-23
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ CHANGELOG
44

55
1.7.1dev
66
========
7+
8+
* bug-fix: Session: use existing model instead of failing during ``create_model()``
79
* deprecate enable_cloudwatch_metrics from Framework Estimators.
810

911
1.7.0

src/sagemaker/session.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,15 +439,24 @@ def create_model(self, name, role, primary_container):
439439
role = self.expand_role(role)
440440
primary_container = _expand_container_def(primary_container)
441441
LOGGER.info('Creating model with name: {}'.format(name))
442-
LOGGER.debug("create_model request: {}".format({
442+
LOGGER.debug('create_model request: {}'.format({
443443
'name': name,
444444
'role': role,
445445
'primary_container': primary_container
446446
}))
447447

448-
self.sagemaker_client.create_model(ModelName=name,
449-
PrimaryContainer=primary_container,
450-
ExecutionRoleArn=role)
448+
try:
449+
self.sagemaker_client.create_model(ModelName=name,
450+
PrimaryContainer=primary_container,
451+
ExecutionRoleArn=role)
452+
except ClientError as e:
453+
error_code = e.response['Error']['Code']
454+
message = e.response['Error']['Message']
455+
456+
if error_code == 'ValidationException' and 'Cannot create already existing model' in message:
457+
LOGGER.warning('Using already existing model: {}'.format(name))
458+
else:
459+
raise
451460

452461
return name
453462

tests/unit/test_session.py

Lines changed: 58 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,17 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import pytest
15+
import datetime
1616
import io
17+
import logging
18+
19+
import pytest
1720
import six
21+
from botocore.exceptions import ClientError
1822
from mock import Mock, patch, call
23+
1924
import sagemaker
2025
from sagemaker import s3_input, Session, get_execution_role
21-
import datetime
22-
23-
from botocore.exceptions import ClientError
24-
2526
from sagemaker.session import _tuning_job_status, _transform_job_status
2627

2728
REGION = 'us-west-2'
@@ -502,18 +503,57 @@ def test_logs_for_job_full_lifecycle(time, cw, sagemaker_session_full_lifecycle)
502503
call(0, 'hi there #2a'), call(0, 'hi there #3')]
503504

504505

506+
MODEL_NAME = 'some-model'
507+
PRIMARY_CONTAINER = {
508+
'Environment': {},
509+
'Image': IMAGE,
510+
'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
511+
}
512+
513+
514+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
515+
def test_create_model(expand_container_def, sagemaker_session):
516+
model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
517+
518+
assert model == MODEL_NAME
519+
sagemaker_session.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
520+
ModelName=MODEL_NAME,
521+
PrimaryContainer=PRIMARY_CONTAINER)
522+
523+
524+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
525+
def test_create_model_already_exists(expand_container_def, sagemaker_session, caplog):
526+
error_response = {'Error': {'Code': 'ValidationException', 'Message': 'Cannot create already existing model'}}
527+
exception = ClientError(error_response, 'Operation')
528+
sagemaker_session.sagemaker_client.create_model.side_effect = exception
529+
530+
model = sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
531+
assert model == MODEL_NAME
532+
533+
expected_warning = ('sagemaker', logging.WARNING, 'Using already existing model: {}'.format(MODEL_NAME))
534+
assert expected_warning in caplog.record_tuples
535+
536+
537+
@patch('sagemaker.session._expand_container_def', return_value=PRIMARY_CONTAINER)
538+
def test_create_model_failure(expand_container_def, sagemaker_session):
539+
error_message = 'this is expected'
540+
sagemaker_session.sagemaker_client.create_model.side_effect = RuntimeError(error_message)
541+
542+
with pytest.raises(RuntimeError) as e:
543+
sagemaker_session.create_model(MODEL_NAME, ROLE, PRIMARY_CONTAINER)
544+
545+
assert error_message in str(e)
546+
547+
505548
def test_create_model_from_job(sagemaker_session):
506549
ims = sagemaker_session
507550
ims.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
508551
ims.create_model_from_job(JOB_NAME)
509552

510-
assert call(TrainingJobName='jobname') in ims.sagemaker_client.describe_training_job.call_args_list
511-
ims.sagemaker_client.create_model.assert_called_with(
512-
ExecutionRoleArn='arn:aws:iam::111111111111:role/ExpandedRole',
513-
ModelName='jobname',
514-
PrimaryContainer={
515-
'Environment': {}, 'ModelDataUrl': 's3://sagemaker-123/output/jobname/model/model.tar.gz',
516-
'Image': 'myimage'})
553+
assert call(TrainingJobName=JOB_NAME) in ims.sagemaker_client.describe_training_job.call_args_list
554+
ims.sagemaker_client.create_model.assert_called_with(ExecutionRoleArn=EXPANDED_ROLE,
555+
ModelName=JOB_NAME,
556+
PrimaryContainer=PRIMARY_CONTAINER)
517557

518558

519559
def test_create_model_from_job_with_image(sagemaker_session):
@@ -592,7 +632,8 @@ def test_endpoint_from_production_variants_with_tags(sagemaker_session):
592632
Tags=tags)
593633

594634

595-
def test_wait_for_tuning_job(sagemaker_session):
635+
@patch('time.sleep')
636+
def test_wait_for_tuning_job(sleep, sagemaker_session):
596637
hyperparameter_tuning_job_desc = {'HyperParameterTuningJobStatus': 'Completed'}
597638
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
598639
name='describe_hyper_parameter_tuning_job', return_value=hyperparameter_tuning_job_desc)
@@ -621,15 +662,17 @@ def test_tune_job_status_none(sagemaker_session):
621662
assert result is None
622663

623664

624-
def test_wait_for_transform_job_completed(sagemaker_session):
665+
@patch('time.sleep')
666+
def test_wait_for_transform_job_completed(sleep, sagemaker_session):
625667
transform_job_desc = {'TransformJobStatus': 'Completed'}
626668
sagemaker_session.sagemaker_client.describe_transform_job = Mock(
627669
name='describe_transform_job', return_value=transform_job_desc)
628670

629671
assert sagemaker_session.wait_for_transform_job(JOB_NAME)['TransformJobStatus'] == 'Completed'
630672

631673

632-
def test_wait_for_transform_job_in_progress(sagemaker_session):
674+
@patch('time.sleep')
675+
def test_wait_for_transform_job_in_progress(sleep, sagemaker_session):
633676
transform_job_desc_in_progress = {'TransformJobStatus': 'InProgress'}
634677
transform_job_desc_in_completed = {'TransformJobStatus': 'Completed'}
635678
sagemaker_session.sagemaker_client.describe_transform_job = Mock(

tests/unit/test_tf_estimator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,8 @@ def test_run_tensorboard_locally_without_awscli_binary(time, strftime, popen, ca
308308
@patch('subprocess.Popen')
309309
@patch('time.strftime', return_value=TIMESTAMP)
310310
@patch('time.time', return_value=TIME)
311-
def test_run_tensorboard_locally(time, strftime, popen, call, access, rmtree, mkdtemp, sync, sagemaker_session):
311+
@patch('time.sleep')
312+
def test_run_tensorboard_locally(sleep, time, strftime, popen, call, access, rmtree, mkdtemp, sync, sagemaker_session):
312313
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
313314
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)
314315

@@ -318,8 +319,7 @@ def test_run_tensorboard_locally(time, strftime, popen, call, access, rmtree, mk
318319

319320
popen.assert_called_with(['tensorboard', '--logdir', '/my/temp/folder', '--host', 'localhost', '--port', '6006'],
320321
stderr=-1,
321-
stdout=-1
322-
)
322+
stdout=-1)
323323

324324

325325
@patch('sagemaker.tensorflow.estimator.Tensorboard._sync_directories')
@@ -331,7 +331,8 @@ def test_run_tensorboard_locally(time, strftime, popen, call, access, rmtree, mk
331331
@patch('subprocess.Popen')
332332
@patch('time.strftime', return_value=TIMESTAMP)
333333
@patch('time.time', return_value=TIME)
334-
def test_run_tensorboard_locally_port_in_use(time, strftime, popen, call, access, socket, rmtree, mkdtemp, sync,
334+
@patch('time.sleep')
335+
def test_run_tensorboard_locally_port_in_use(sleep, time, strftime, popen, call, access, socket, rmtree, mkdtemp, sync,
335336
sagemaker_session):
336337
tf = TensorFlow(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
337338
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE)

0 commit comments

Comments
 (0)