Skip to content

fix: make instance type fields as optional #3135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
8 changes: 4 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,8 +1286,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
image_uri=None,
model_package_name=None,
model_package_group_name=None,
Expand All @@ -1309,9 +1309,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
image_uri (str): The container image uri for Model Package, if not specified,
Estimator's training container image will be used (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -313,9 +313,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned.
Defaults to ``None``.
Expand All @@ -341,7 +341,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
13 changes: 6 additions & 7 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -317,9 +317,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand Down Expand Up @@ -351,12 +351,11 @@ def register(
container_def = self.prepare_container_def()
else:
container_def = {"Image": self.image_uri, "ModelDataUrl": self.model_data}

model_pkg_args = sagemaker.get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=inference_instances,
transform_instances=transform_instances,
model_package_name=model_package_name,
model_package_group_name=model_package_group_name,
model_metrics=model_metrics,
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -166,9 +166,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -192,7 +192,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
23 changes: 14 additions & 9 deletions src/sagemaker/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
self.enable_network_isolation = enable_network_isolation
self.endpoint_name = None

def pipeline_container_def(self, instance_type):
def pipeline_container_def(self, instance_type=None):
"""The pipeline definition for deploying this model.

This is the dict created by ``sagemaker.pipeline_container_def()``.
Expand Down Expand Up @@ -266,8 +266,8 @@ def register(
self,
content_types: list,
response_types: list,
inference_instances: list,
transform_instances: list,
inference_instances: Optional[list] = None,
transform_instances: Optional[list] = None,
model_package_name: Optional[str] = None,
model_package_group_name: Optional[str] = None,
image_uri: Optional[str] = None,
Expand All @@ -286,9 +286,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand Down Expand Up @@ -316,18 +316,23 @@ def register(
if model.model_data is None:
raise ValueError("SageMaker Model Package cannot be created without model data.")
if model_package_group_name is not None:
container_def = self.pipeline_container_def(inference_instances[0])
container_def = self.pipeline_container_def(
inference_instances[0] if inference_instances else None
)
else:
container_def = [
{"Image": image_uri or model.image_uri, "ModelDataUrl": model.model_data}
{
"Image": image_uri or model.image_uri,
"ModelDataUrl": model.model_data,
}
for model in self.models
]

model_pkg_args = sagemaker.get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=inference_instances,
transform_instances=transform_instances,
model_package_name=model_package_name,
model_package_group_name=model_package_group_name,
model_metrics=model_metrics,
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -167,9 +167,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -193,7 +193,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
13 changes: 6 additions & 7 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4206,8 +4206,8 @@ def _intercept_create_request(
def get_model_package_args(
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
model_data=None,
Expand All @@ -4230,9 +4230,9 @@ def get_model_package_args(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand Down Expand Up @@ -4377,10 +4377,9 @@ def get_create_model_package_request(
if domain is not None:
request_dict["Domain"] = domain
if containers is not None:
if not all([content_types, response_types, inference_instances, transform_instances]):
if not all([content_types, response_types]):
raise ValueError(
"content_types, response_types, inference_inferences and transform_instances "
"must be provided if containers is present."
"content_types and response_types " "must be provided if containers is present."
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This thrown error was added by sreedes@. Did you check with her to see if she agree on this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed the same with Saumitra and got approval for this change

inference_specification = {
"Containers": containers,
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -161,9 +161,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -187,7 +187,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
10 changes: 5 additions & 5 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,8 @@ def register(
self,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
model_package_name=None,
model_package_group_name=None,
image_uri=None,
Expand All @@ -213,9 +213,9 @@ def register(
content_types (list): The supported MIME types for the input data.
response_types (list): The supported MIME types for the output data.
inference_instances (list): A list of the instance types that are used to
generate inferences in real-time.
generate inferences in real-time (default: None).
transform_instances (list): A list of the instance types on which a transformation
job can be run or on which an endpoint can be deployed.
job can be run or on which an endpoint can be deployed (default: None).
model_package_name (str): Model Package name, exclusive to `model_package_group_name`,
using `model_package_name` makes the Model Package un-versioned (default: None).
model_package_group_name (str): Model Package Group name, exclusive to
Expand All @@ -239,7 +239,7 @@ def register(
Returns:
A `sagemaker.model.ModelPackage` instance.
"""
instance_type = inference_instances[0]
instance_type = inference_instances[0] if inference_instances else None
self._init_sagemaker_session_if_does_not_exist(instance_type)

if image_uri:
Expand Down
14 changes: 10 additions & 4 deletions src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def __init__(
name: str,
content_types,
response_types,
inference_instances,
transform_instances,
inference_instances=None,
transform_instances=None,
estimator: EstimatorBase = None,
model_data=None,
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
Expand Down Expand Up @@ -220,9 +220,15 @@ def __init__(
kwargs.pop("output_kms_key", None)

if isinstance(model, PipelineModel):
self.container_def_list = model.pipeline_container_def(inference_instances[0])
self.container_def_list = model.pipeline_container_def(
inference_instances[0] if inference_instances else None
)
elif isinstance(model, Model):
self.container_def_list = [model.prepare_container_def(inference_instances[0])]
self.container_def_list = [
model.prepare_container_def(
inference_instances[0] if inference_instances else None
)
]

register_model_step = _RegisterModelStep(
name=name,
Expand Down
49 changes: 49 additions & 0 deletions tests/unit/sagemaker/workflow/test_pipeline_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,52 @@ def test_pipeline_session_context_for_model_step(pipeline_session_mock):
assert not register_step_args.create_model_request
assert register_step_args.create_model_package_request
assert len(register_step_args.need_runtime_repack) == 0


def test_pipeline_session_context_for_model_step_without_instance_types(
pipeline_session_mock,
):
model = Model(
name="MyModel",
image_uri="fakeimage",
model_data=ParameterString(name="ModelData", default_value="s3://my-bucket/file"),
sagemaker_session=pipeline_session_mock,
entry_point=f"{DATA_DIR}/dummy_script.py",
source_dir=f"{DATA_DIR}",
role=_ROLE,
)

register_step_args = model.register(
content_types=["text/csv"],
response_types=["text/csv"],
model_package_group_name="MyModelPackageGroup",
)

expected_output = {
"ModelPackageGroupName": "MyModelPackageGroup",
"InferenceSpecification": {
"Containers": [
{
"Image": "fakeimage",
"Environment": {
"SAGEMAKER_PROGRAM": "dummy_script.py",
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
"SAGEMAKER_REGION": "us-west-2",
},
"ModelDataUrl": ParameterString(
name="ModelData",
default_value="s3://my-bucket/file",
),
}
],
"SupportedContentTypes": ["text/csv"],
"SupportedResponseMIMETypes": ["text/csv"],
"SupportedRealtimeInferenceInstanceTypes": None,
"SupportedTransformInstanceTypes": None,
},
"CertifyForMarketplace": False,
"ModelApprovalStatus": "PendingManualApproval",
}

assert register_step_args.create_model_package_request == expected_output
Loading