|
28 | 28 | )
|
29 | 29 |
|
30 | 30 | from sagemaker.estimator import Estimator
|
| 31 | +from sagemaker.tensorflow import TensorFlow |
31 | 32 | from sagemaker.inputs import CreateModelInput, TransformInput
|
32 | 33 | from sagemaker.model_metrics import (
|
33 | 34 | MetricsSource,
|
@@ -119,6 +120,17 @@ def estimator(sagemaker_session):
|
119 | 120 | sagemaker_session=sagemaker_session,
|
120 | 121 | )
|
121 | 122 |
|
| 123 | +@pytest.fixture |
| 124 | +def estimator_tf(sagemaker_session): |
| 125 | + return TensorFlow( |
| 126 | + entry_point="/some/script.py", |
| 127 | + framework_version="1.15.2", |
| 128 | + py_version="py3", |
| 129 | + role=ROLE, |
| 130 | + instance_type="ml.c4.2xlarge", |
| 131 | + instance_count=1, |
| 132 | + sagemaker_session=sagemaker_session, |
| 133 | + ) |
122 | 134 |
|
123 | 135 | @pytest.fixture
|
124 | 136 | def model_metrics():
|
@@ -201,6 +213,55 @@ def test_register_model(estimator, model_metrics):
|
201 | 213 | ]
|
202 | 214 | )
|
203 | 215 |
|
| 216 | +def test_register_model_tf(estimator_tf, model_metrics): |
| 217 | + model_data = f"s3://{BUCKET}/model.tar.gz" |
| 218 | + register_model = RegisterModel( |
| 219 | + name="RegisterModelStep", |
| 220 | + estimator=estimator_tf, |
| 221 | + model_data=model_data, |
| 222 | + content_types=["content_type"], |
| 223 | + response_types=["response_type"], |
| 224 | + inference_instances=["inference_instance"], |
| 225 | + transform_instances=["transform_instance"], |
| 226 | + model_package_group_name="mpg", |
| 227 | + model_metrics=model_metrics, |
| 228 | + approval_status="Approved", |
| 229 | + description="description", |
| 230 | + ) |
| 231 | + assert ordered(register_model.request_dicts()) == ordered( |
| 232 | + [ |
| 233 | + { |
| 234 | + "Name": "RegisterModelStep", |
| 235 | + "Type": "RegisterModel", |
| 236 | + "Arguments": { |
| 237 | + "InferenceSpecification": { |
| 238 | + "Containers": [ |
| 239 | + { |
| 240 | + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu", |
| 241 | + "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", |
| 242 | + } |
| 243 | + ], |
| 244 | + "SupportedContentTypes": ["content_type"], |
| 245 | + "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"], |
| 246 | + "SupportedResponseMIMETypes": ["response_type"], |
| 247 | + "SupportedTransformInstanceTypes": ["transform_instance"], |
| 248 | + }, |
| 249 | + "ModelApprovalStatus": "Approved", |
| 250 | + "ModelMetrics": { |
| 251 | + "ModelQuality": { |
| 252 | + "Statistics": { |
| 253 | + "ContentType": "text/csv", |
| 254 | + "S3Uri": f"s3://{BUCKET}/metrics.csv", |
| 255 | + }, |
| 256 | + }, |
| 257 | + }, |
| 258 | + "ModelPackageDescription": "description", |
| 259 | + "ModelPackageGroupName": "mpg", |
| 260 | + }, |
| 261 | + }, |
| 262 | + ] |
| 263 | + ) |
| 264 | + |
204 | 265 |
|
205 | 266 | def test_register_model_with_model_repack(estimator, model_metrics):
|
206 | 267 | model_data = f"s3://{BUCKET}/model.tar.gz"
|
|
0 commit comments