Skip to content

Commit b69818b

Browse files
authored
Merge branch 'master' into xgboost-1.2-2
2 parents 60d5ed0 + 529752c commit b69818b

File tree

4 files changed

+79
-2
lines changed

4 files changed

+79
-2
lines changed

CHANGELOG.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
11
# Changelog
22

3+
## v2.39.1 (2021-05-05)
4+
5+
### Bug Fixes and Other Changes
6+
7+
* RegisterModel step and custom dependency support
8+
9+
### Documentation Changes
10+
11+
* reverting SageMaker distributed data parallel EFA doc updates
12+
* adding new version, SM dist. data parallel 1.2.0.
13+
* add current Hugging Face supported versions
14+
* SMDDP 1.2.0 release notes
15+
316
## v2.39.0.post0 (2021-05-04)
417

518
### Testing and Release Infrastructure

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.39.1.dev0
1+
2.39.2.dev0

src/sagemaker/workflow/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def arguments(self) -> RequestType:
307307
model._framework_name,
308308
region_name,
309309
version=model.framework_version,
310-
py_version=model.py_version,
310+
py_version=model.py_version if hasattr(model, "py_version") else None,
311311
instance_type=self.kwargs.get("instance_type", self.estimator.instance_type),
312312
accelerator_type=self.kwargs.get("accelerator_type"),
313313
image_scope="inference",

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929

3030
from sagemaker.estimator import Estimator
31+
from sagemaker.tensorflow import TensorFlow
3132
from sagemaker.inputs import CreateModelInput, TransformInput
3233
from sagemaker.model_metrics import (
3334
MetricsSource,
@@ -120,6 +121,19 @@ def estimator(sagemaker_session):
120121
)
121122

122123

124+
@pytest.fixture
125+
def estimator_tf(sagemaker_session):
126+
return TensorFlow(
127+
entry_point="/some/script.py",
128+
framework_version="1.15.2",
129+
py_version="py3",
130+
role=ROLE,
131+
instance_type="ml.c4.2xlarge",
132+
instance_count=1,
133+
sagemaker_session=sagemaker_session,
134+
)
135+
136+
123137
@pytest.fixture
124138
def model_metrics():
125139
return ModelMetrics(
@@ -202,6 +216,56 @@ def test_register_model(estimator, model_metrics):
202216
)
203217

204218

219+
def test_register_model_tf(estimator_tf, model_metrics):
220+
model_data = f"s3://{BUCKET}/model.tar.gz"
221+
register_model = RegisterModel(
222+
name="RegisterModelStep",
223+
estimator=estimator_tf,
224+
model_data=model_data,
225+
content_types=["content_type"],
226+
response_types=["response_type"],
227+
inference_instances=["inference_instance"],
228+
transform_instances=["transform_instance"],
229+
model_package_group_name="mpg",
230+
model_metrics=model_metrics,
231+
approval_status="Approved",
232+
description="description",
233+
)
234+
assert ordered(register_model.request_dicts()) == ordered(
235+
[
236+
{
237+
"Name": "RegisterModelStep",
238+
"Type": "RegisterModel",
239+
"Arguments": {
240+
"InferenceSpecification": {
241+
"Containers": [
242+
{
243+
"Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu",
244+
"ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
245+
}
246+
],
247+
"SupportedContentTypes": ["content_type"],
248+
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
249+
"SupportedResponseMIMETypes": ["response_type"],
250+
"SupportedTransformInstanceTypes": ["transform_instance"],
251+
},
252+
"ModelApprovalStatus": "Approved",
253+
"ModelMetrics": {
254+
"ModelQuality": {
255+
"Statistics": {
256+
"ContentType": "text/csv",
257+
"S3Uri": f"s3://{BUCKET}/metrics.csv",
258+
},
259+
},
260+
},
261+
"ModelPackageDescription": "description",
262+
"ModelPackageGroupName": "mpg",
263+
},
264+
},
265+
]
266+
)
267+
268+
205269
def test_register_model_with_model_repack(estimator, model_metrics):
206270
model_data = f"s3://{BUCKET}/model.tar.gz"
207271
register_model = RegisterModel(

0 commit comments

Comments
 (0)