Skip to content

Commit f5ecbde

Browse files
author
Payton Staub
committed
Check py_version existence in RegisterModel step
1 parent 6d39762 commit f5ecbde

File tree

2 files changed

+62
-1
lines changed

2 files changed

+62
-1
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def arguments(self) -> RequestType:
307307
model._framework_name,
308308
region_name,
309309
version=model.framework_version,
310-
py_version=model.py_version,
310+
py_version=model.py_version if hasattr(model, "py_version") else None,
311311
instance_type=self.kwargs.get("instance_type", self.estimator.instance_type),
312312
accelerator_type=self.kwargs.get("accelerator_type"),
313313
image_scope="inference",

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929

3030
from sagemaker.estimator import Estimator
31+
from sagemaker.tensorflow import TensorFlow
3132
from sagemaker.inputs import CreateModelInput, TransformInput
3233
from sagemaker.model_metrics import (
3334
MetricsSource,
@@ -119,6 +120,17 @@ def estimator(sagemaker_session):
119120
sagemaker_session=sagemaker_session,
120121
)
121122

123+
@pytest.fixture
124+
def estimator_tf(sagemaker_session):
125+
return TensorFlow(
126+
entry_point="/some/script.py",
127+
framework_version="1.15.2",
128+
py_version="py3",
129+
role=ROLE,
130+
instance_type="ml.c4.2xlarge",
131+
instance_count=1,
132+
sagemaker_session=sagemaker_session,
133+
)
122134

123135
@pytest.fixture
124136
def model_metrics():
@@ -201,6 +213,55 @@ def test_register_model(estimator, model_metrics):
201213
]
202214
)
203215

216+
def test_register_model_tf(estimator_tf, model_metrics):
217+
model_data = f"s3://{BUCKET}/model.tar.gz"
218+
register_model = RegisterModel(
219+
name="RegisterModelStep",
220+
estimator=estimator_tf,
221+
model_data=model_data,
222+
content_types=["content_type"],
223+
response_types=["response_type"],
224+
inference_instances=["inference_instance"],
225+
transform_instances=["transform_instance"],
226+
model_package_group_name="mpg",
227+
model_metrics=model_metrics,
228+
approval_status="Approved",
229+
description="description",
230+
)
231+
assert ordered(register_model.request_dicts()) == ordered(
232+
[
233+
{
234+
"Name": "RegisterModelStep",
235+
"Type": "RegisterModel",
236+
"Arguments": {
237+
"InferenceSpecification": {
238+
"Containers": [
239+
{
240+
"Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu",
241+
"ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
242+
}
243+
],
244+
"SupportedContentTypes": ["content_type"],
245+
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
246+
"SupportedResponseMIMETypes": ["response_type"],
247+
"SupportedTransformInstanceTypes": ["transform_instance"],
248+
},
249+
"ModelApprovalStatus": "Approved",
250+
"ModelMetrics": {
251+
"ModelQuality": {
252+
"Statistics": {
253+
"ContentType": "text/csv",
254+
"S3Uri": f"s3://{BUCKET}/metrics.csv",
255+
},
256+
},
257+
},
258+
"ModelPackageDescription": "description",
259+
"ModelPackageGroupName": "mpg",
260+
},
261+
},
262+
]
263+
)
264+
204265

205266
def test_register_model_with_model_repack(estimator, model_metrics):
206267
model_data = f"s3://{BUCKET}/model.tar.gz"

0 commit comments

Comments
 (0)