|
28 | 28 | )
|
29 | 29 |
|
30 | 30 | from sagemaker.estimator import Estimator
|
| 31 | +from sagemaker.tensorflow import TensorFlow |
31 | 32 | from sagemaker.inputs import CreateModelInput, TransformInput
|
32 | 33 | from sagemaker.model_metrics import (
|
33 | 34 | MetricsSource,
|
@@ -120,6 +121,19 @@ def estimator(sagemaker_session):
|
120 | 121 | )
|
121 | 122 |
|
122 | 123 |
|
| 124 | +@pytest.fixture |
| 125 | +def estimator_tf(sagemaker_session): |
| 126 | + return TensorFlow( |
| 127 | + entry_point="/some/script.py", |
| 128 | + framework_version="1.15.2", |
| 129 | + py_version="py3", |
| 130 | + role=ROLE, |
| 131 | + instance_type="ml.c4.2xlarge", |
| 132 | + instance_count=1, |
| 133 | + sagemaker_session=sagemaker_session, |
| 134 | + ) |
| 135 | + |
| 136 | + |
123 | 137 | @pytest.fixture
|
124 | 138 | def model_metrics():
|
125 | 139 | return ModelMetrics(
|
@@ -202,6 +216,56 @@ def test_register_model(estimator, model_metrics):
|
202 | 216 | )
|
203 | 217 |
|
204 | 218 |
|
| 219 | +def test_register_model_tf(estimator_tf, model_metrics): |
| 220 | + model_data = f"s3://{BUCKET}/model.tar.gz" |
| 221 | + register_model = RegisterModel( |
| 222 | + name="RegisterModelStep", |
| 223 | + estimator=estimator_tf, |
| 224 | + model_data=model_data, |
| 225 | + content_types=["content_type"], |
| 226 | + response_types=["response_type"], |
| 227 | + inference_instances=["inference_instance"], |
| 228 | + transform_instances=["transform_instance"], |
| 229 | + model_package_group_name="mpg", |
| 230 | + model_metrics=model_metrics, |
| 231 | + approval_status="Approved", |
| 232 | + description="description", |
| 233 | + ) |
| 234 | + assert ordered(register_model.request_dicts()) == ordered( |
| 235 | + [ |
| 236 | + { |
| 237 | + "Name": "RegisterModelStep", |
| 238 | + "Type": "RegisterModel", |
| 239 | + "Arguments": { |
| 240 | + "InferenceSpecification": { |
| 241 | + "Containers": [ |
| 242 | + { |
| 243 | + "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu", |
| 244 | + "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz", |
| 245 | + } |
| 246 | + ], |
| 247 | + "SupportedContentTypes": ["content_type"], |
| 248 | + "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"], |
| 249 | + "SupportedResponseMIMETypes": ["response_type"], |
| 250 | + "SupportedTransformInstanceTypes": ["transform_instance"], |
| 251 | + }, |
| 252 | + "ModelApprovalStatus": "Approved", |
| 253 | + "ModelMetrics": { |
| 254 | + "ModelQuality": { |
| 255 | + "Statistics": { |
| 256 | + "ContentType": "text/csv", |
| 257 | + "S3Uri": f"s3://{BUCKET}/metrics.csv", |
| 258 | + }, |
| 259 | + }, |
| 260 | + }, |
| 261 | + "ModelPackageDescription": "description", |
| 262 | + "ModelPackageGroupName": "mpg", |
| 263 | + }, |
| 264 | + }, |
| 265 | + ] |
| 266 | + ) |
| 267 | + |
| 268 | + |
205 | 269 | def test_register_model_with_model_repack(estimator, model_metrics):
|
206 | 270 | model_data = f"s3://{BUCKET}/model.tar.gz"
|
207 | 271 | register_model = RegisterModel(
|
|
0 commit comments