Skip to content

fix: add better default transform job name handling within Transformer #822

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/sagemaker/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ def transform(self, data, data_type='S3Prefix', content_type=None, compression_t
if job_name is not None:
self._current_job_name = job_name
else:
base_name = self.base_transform_job_name or base_name_from_image(self._retrieve_image_name())
base_name = self.base_transform_job_name

if base_name is None:
base_name = self._retrieve_base_name()

self._current_job_name = name_from_base(base_name)

if self.output_path is None:
Expand All @@ -120,10 +124,28 @@ def delete_model(self):
"""
self.sagemaker_session.delete_model(self.model_name)

def _retrieve_base_name(self):
image_name = self._retrieve_image_name()

if image_name:
return base_name_from_image(image_name)

return self.model_name

def _retrieve_image_name(self):
try:
model_desc = self.sagemaker_session.sagemaker_client.describe_model(ModelName=self.model_name)
return model_desc['PrimaryContainer']['Image']

primary_container = model_desc.get('PrimaryContainer')
if primary_container:
return primary_container.get('Image')

containers = model_desc.get('Containers')
if containers:
return containers[0].get('Image')

return None

except exceptions.ClientError:
raise ValueError('Failed to fetch model information for %s. '
'Please ensure that the model exists. '
Expand Down
75 changes: 69 additions & 6 deletions tests/unit/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest
from mock import MagicMock, Mock, patch

from sagemaker.transformer import Transformer, _TransformJob
from sagemaker.transformer import _TransformJob, Transformer
from tests.integ import test_local_mode

MODEL_NAME = 'model'
Expand All @@ -40,6 +40,18 @@
'base_transform_job_name': JOB_NAME
}

MODEL_DESC_PRIMARY_CONTAINER = {
'PrimaryContainer': {
'Image': IMAGE_NAME
}
}

MODEL_DESC_CONTAINERS_ONLY = {
'Containers': [
{'Image': IMAGE_NAME}
]
}


@pytest.fixture(autouse=True)
def mock_create_tar_file():
Expand Down Expand Up @@ -97,29 +109,80 @@ def test_transform_with_all_params(start_new_job, transformer):

@patch('sagemaker.transformer.name_from_base')
@patch('sagemaker.transformer._TransformJob.start_new')
def test_transform_with_base_job_name(start_new_job, name_from_base, transformer):
def test_transform_with_base_job_name_provided(start_new_job, name_from_base, transformer):
base_name = 'base-job-name'
full_name = '{}-{}'.format(base_name, TIMESTAMP)

transformer.base_transform_job_name = base_name
name_from_base.return_value = full_name

transformer.transform(DATA)
assert name_from_base.called_with(base_name)

name_from_base.assert_called_once_with(base_name)
assert transformer._current_job_name == full_name


@patch('sagemaker.transformer.Transformer._retrieve_base_name', return_value=IMAGE_NAME)
@patch('sagemaker.transformer.name_from_base')
@patch('sagemaker.transformer._TransformJob.start_new')
def test_transform_with_base_name(start_new_job, name_from_base, retrieve_base_name, transformer):
full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
name_from_base.return_value = full_name

transformer.transform(DATA)

retrieve_base_name.assert_called_once_with()
name_from_base.assert_called_once_with(IMAGE_NAME)
assert transformer._current_job_name == full_name


@patch('sagemaker.transformer.Transformer._retrieve_image_name', return_value=IMAGE_NAME)
@patch('sagemaker.transformer.name_from_base')
@patch('sagemaker.transformer._TransformJob.start_new')
def test_transform_with_fully_generated_job_name(start_new_job, name_from_base, retrieve_image_name, transformer):
def test_transform_with_job_name_based_on_image(start_new_job, name_from_base, retrieve_image_name, transformer):
full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
name_from_base.return_value = full_name

transformer.transform(DATA)

assert retrieve_image_name.called_once
assert name_from_base.called_with(IMAGE_NAME)
retrieve_image_name.assert_called_once_with()
name_from_base.assert_called_once_with(IMAGE_NAME)
assert transformer._current_job_name == full_name


@pytest.mark.parametrize('model_desc', [MODEL_DESC_PRIMARY_CONTAINER,
MODEL_DESC_CONTAINERS_ONLY])
@patch('sagemaker.transformer.name_from_base')
@patch('sagemaker.transformer._TransformJob.start_new')
def test_transform_with_job_name_based_on_containers(start_new_job, name_from_base, model_desc, transformer):
transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc

full_name = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
name_from_base.return_value = full_name

transformer.transform(DATA)

transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME)
name_from_base.assert_called_once_with(IMAGE_NAME)
assert transformer._current_job_name == full_name


@pytest.mark.parametrize('model_desc', [{'PrimaryContainer': dict()},
{'Containers': [dict()]},
dict(),
])
@patch('sagemaker.transformer.name_from_base')
@patch('sagemaker.transformer._TransformJob.start_new')
def test_transform_with_job_name_based_on_model_name(start_new_job, name_from_base, model_desc, transformer):
transformer.sagemaker_session.sagemaker_client.describe_model.return_value = model_desc

full_name = '{}-{}'.format(MODEL_NAME, TIMESTAMP)
name_from_base.return_value = full_name

transformer.transform(DATA)

transformer.sagemaker_session.sagemaker_client.describe_model.assert_called_once_with(ModelName=MODEL_NAME)
name_from_base.assert_called_once_with(MODEL_NAME)
assert transformer._current_job_name == full_name


Expand Down