Skip to content

breaking: change Model parameter order to make model_data optional #1579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/factorization_machines.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
repo = "{}:{}".format(FactorizationMachines.repo_name, FactorizationMachines.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(FactorizationMachinesModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=FactorizationMachinesPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/ipinsights.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
)

super(IPInsightsModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=IPInsightsPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
repo = "{}:{}".format(KMeans.repo_name, KMeans.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(KMeansModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=KMeansPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
registry(sagemaker_session.boto_session.region_name, KNN.repo_name), repo
)
super(KNNModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=KNNPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
registry(sagemaker_session.boto_session.region_name, LDA.repo_name), repo
)
super(LDAModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=LDAPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/linear_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,8 +474,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
repo = "{}:{}".format(LinearLearner.repo_name, LinearLearner.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(LinearLearnerModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=LinearLearnerPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/ntm.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
registry(sagemaker_session.boto_session.region_name, NTM.repo_name), repo
)
super(NTMModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=NTMPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/object2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
registry(sagemaker_session.boto_session.region_name, Object2Vec.repo_name), repo
)
super(Object2VecModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=RealTimePredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
repo = "{}:{}".format(PCA.repo_name, PCA.repo_version)
image = "{}/{}".format(registry(sagemaker_session.boto_session.region_name), repo)
super(PCAModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=PCAPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/randomcutforest.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def __init__(self, model_data, role, sagemaker_session=None, **kwargs):
registry(sagemaker_session.boto_session.region_name, RandomCutForest.repo_name), repo
)
super(RandomCutForestModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=RandomCutForestPredictor,
sagemaker_session=sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,8 +1398,8 @@ def predict_wrapper(endpoint, session):
kwargs["enable_network_isolation"] = self.enable_network_isolation()

return Model(
self.model_data,
image or self.train_image(),
self.model_data,
role,
vpc_config=self.get_vpc_config(vpc_config_override),
sagemaker_session=self.sagemaker_session,
Expand Down
10 changes: 6 additions & 4 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ class Model(object):

def __init__(
self,
model_data,
image,
model_data=None,
role=None,
predictor_cls=None,
env=None,
Expand All @@ -74,9 +74,9 @@ def __init__(
"""Initialize an SageMaker ``Model``.

Args:
model_data (str): The S3 location of a SageMaker model data
``.tar.gz`` file.
image (str): A Docker image URI.
model_data (str): The S3 location of a SageMaker model data
``.tar.gz`` file (default: None).
role (str): An AWS IAM role (either name or full ARN). The Amazon
SageMaker training jobs and APIs that create Amazon SageMaker
endpoints use this role to access training data and model
Expand Down Expand Up @@ -361,6 +361,8 @@ def compile(
)
if job_name is None:
raise ValueError("You must provide a compilation job name")
if self.model_data is None:
raise ValueError("You must provide an S3 path to the compressed model artifacts.")

framework = framework.upper()
framework_version = self._get_framework_version() or framework_version
Expand Down Expand Up @@ -778,8 +780,8 @@ def __init__(
:class:`~sagemaker.model.Model`.
"""
super(FrameworkModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=predictor_cls,
env=env,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def __init__(
# Set the ``Model`` parameters if the model parameter is not specified
if not self.model:
super(MultiDataModel, self).__init__(
self.model_data_prefix,
image,
self.model_data_prefix,
role,
name=self.name,
sagemaker_session=self.sagemaker_session,
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/sparkml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def __init__(self, model_data, role=None, spark_version=2.2, sagemaker_session=N
region_name = (sagemaker_session or Session()).boto_region_name
image = "{}/{}:{}".format(registry(region_name, framework_name), repo_name, spark_version)
super(SparkMLModel, self).__init__(
model_data,
image,
model_data,
role,
predictor_cls=SparkMLPredictor,
sagemaker_session=sagemaker_session,
Expand Down
26 changes: 13 additions & 13 deletions tests/unit/sagemaker/model/test_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
container_def = {"Image": MODEL_IMAGE, "Environment": {}, "ModelDataUrl": MODEL_DATA}
prepare_container_def.return_value = container_def

model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)

name_from_image.assert_called_with(MODEL_IMAGE)
Expand All @@ -81,7 +81,7 @@ def test_deploy(name_from_image, prepare_container_def, production_variant, sage
@patch("sagemaker.production_variant")
def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sagemaker_session):
model = Model(
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)

production_variant_result = copy.deepcopy(BASE_PRODUCTION_VARIANT)
Expand Down Expand Up @@ -113,7 +113,7 @@ def test_deploy_accelerator_type(production_variant, create_sagemaker_model, sag
@patch("sagemaker.model.Model._create_sagemaker_model", Mock())
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_endpoint_name(sagemaker_session):
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)

endpoint_name = "blah"
model.deploy(
Expand All @@ -136,7 +136,7 @@ def test_deploy_endpoint_name(sagemaker_session):
@patch("sagemaker.model.Model._create_sagemaker_model")
def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_session):
model = Model(
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)

tags = [{"Key": "ModelName", "Value": "TestModel"}]
Expand All @@ -157,7 +157,7 @@ def test_deploy_tags(create_sagemaker_model, production_variant, sagemaker_sessi
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_kms_key(production_variant, sagemaker_session):
model = Model(
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)

key = "some-key-arn"
Expand All @@ -177,7 +177,7 @@ def test_deploy_kms_key(production_variant, sagemaker_session):
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_async(production_variant, sagemaker_session):
model = Model(
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)

model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, wait=False)
Expand All @@ -196,7 +196,7 @@ def test_deploy_async(production_variant, sagemaker_session):
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_data_capture_config(production_variant, sagemaker_session):
model = Model(
MODEL_DATA, MODEL_IMAGE, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session
)

data_capture_config = Mock()
Expand All @@ -223,20 +223,20 @@ def test_deploy_data_capture_config(production_variant, sagemaker_session):
@patch("sagemaker.local.LocalSession")
def test_deploy_creates_correct_session(local_session, session):
# We expect a LocalSession when deploying to instance_type = 'local'
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE)
model.deploy(endpoint_name="blah", instance_type="local", initial_instance_count=1)
assert model.sagemaker_session == local_session.return_value

# We expect a real Session when deploying to instance_type != local/local_gpu
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE)
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE)
model.deploy(
endpoint_name="remote_endpoint", instance_type="ml.m4.4xlarge", initial_instance_count=2
)
assert model.sagemaker_session == session.return_value


def test_deploy_no_role(sagemaker_session):
model = Model(MODEL_DATA, MODEL_IMAGE, sagemaker_session=sagemaker_session)
model = Model(MODEL_IMAGE, MODEL_DATA, sagemaker_session=sagemaker_session)

with pytest.raises(ValueError, match="Role can not be null for deploying a model"):
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT)
Expand All @@ -248,8 +248,8 @@ def test_deploy_no_role(sagemaker_session):
@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT)
def test_deploy_predictor_cls(production_variant, sagemaker_session):
model = Model(
MODEL_DATA,
MODEL_IMAGE,
MODEL_DATA,
role=ROLE,
name=MODEL_NAME,
predictor_cls=sagemaker.predictor.RealTimePredictor,
Expand All @@ -269,7 +269,7 @@ def test_deploy_predictor_cls(production_variant, sagemaker_session):


def test_deploy_update_endpoint(sagemaker_session):
model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
model.deploy(
instance_type=INSTANCE_TYPE, initial_instance_count=INSTANCE_COUNT, update_endpoint=True
)
Expand Down Expand Up @@ -300,7 +300,7 @@ def test_deploy_update_endpoint_optional_args(sagemaker_session):
kms_key = "foo"
data_capture_config = Mock()

model = Model(MODEL_DATA, MODEL_IMAGE, role=ROLE, sagemaker_session=sagemaker_session)
model = Model(MODEL_IMAGE, MODEL_DATA, role=ROLE, sagemaker_session=sagemaker_session)
model.deploy(
instance_type=INSTANCE_TYPE,
initial_instance_count=INSTANCE_COUNT,
Expand Down
Loading