|
14 | 14 | from __future__ import absolute_import
|
15 | 15 |
|
16 | 16 | import pytest
|
| 17 | +from tests.unit import DATA_DIR |
17 | 18 |
|
18 | 19 | import sagemaker
|
19 | 20 |
|
|
38 | 39 | StepCollection,
|
39 | 40 | RegisterModel,
|
40 | 41 | )
|
| 42 | +from sagemaker.workflow.pipeline import Pipeline |
41 | 43 | from tests.unit.sagemaker.workflow.helpers import ordered
|
42 | 44 |
|
43 | 45 | REGION = "us-west-2"
|
44 | 46 | BUCKET = "my-bucket"
|
45 | 47 | IMAGE_URI = "fakeimage"
|
46 | 48 | ROLE = "DummyRole"
|
47 | 49 | MODEL_NAME = "gisele"
|
| 50 | +MODEL_REPACKING_IMAGE_URI = ( |
| 51 | + "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3" |
| 52 | +) |
48 | 53 |
|
49 | 54 |
|
50 | 55 | class CustomStep(Step):
|
@@ -177,6 +182,111 @@ def test_register_model(estimator, model_metrics):
|
177 | 182 | )
|
178 | 183 |
|
179 | 184 |
|
def test_register_model_with_model_repack(estimator, model_metrics):
    """Check the requests emitted by ``RegisterModel`` when an entry point is given.

    Supplying ``entry_point`` is expected to expand the step collection into
    two steps: a ``Training`` step named ``RegisterModelStepRepackModel`` and
    the ``RegisterModel`` step itself, named ``RegisterModelStep``. The
    arguments of both are verified here.

    Args:
        estimator: Estimator fixture passed through to ``RegisterModel``.
        model_metrics: Model metrics fixture passed through to ``RegisterModel``.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    register_model = RegisterModel(
        name="RegisterModelStep",
        estimator=estimator,
        model_data=model_data,
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        model_package_group_name="mpg",
        model_metrics=model_metrics,
        approval_status="Approved",
        description="description",
        entry_point=f"{DATA_DIR}/dummy_script.py",
    )

    request_dicts = register_model.request_dicts()
    assert len(request_dicts) == 2
    # The per-type checks below only run for the types that actually occur, so
    # require exactly one step of each expected type up front. Without this,
    # two steps of the same type would pass while one branch is never checked.
    step_types = sorted(request_dict["Type"] for request_dict in request_dicts)
    assert step_types == ["RegisterModel", "Training"]

    for request_dict in request_dicts:
        if request_dict["Type"] == "Training":
            assert request_dict["Name"] == "RegisterModelStepRepackModel"
            arguments = request_dict["Arguments"]
            # The job name is generated at runtime, so read it back from the
            # request and reuse it to build the expected values.
            repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
            # Hyperparameter values are quoted strings; drop the quotes before
            # embedding the job name in the expected S3 path.
            unquoted_job_name = repacker_job_name.replace('"', "")
            assert ordered(arguments) == ordered(
                {
                    "AlgorithmSpecification": {
                        "TrainingImage": MODEL_REPACKING_IMAGE_URI,
                        "TrainingInputMode": "File",
                    },
                    "DebugHookConfig": {
                        "CollectionConfigurations": [],
                        "S3OutputPath": f"s3://{BUCKET}/",
                    },
                    "HyperParameters": {
                        "inference_script": '"dummy_script.py"',
                        "model_archive": '"model.tar.gz"',
                        "sagemaker_submit_directory": (
                            f'"s3://{BUCKET}/{unquoted_job_name}/source/sourcedir.tar.gz"'
                        ),
                        "sagemaker_program": '"_repack_model.py"',
                        "sagemaker_container_log_level": "20",
                        "sagemaker_job_name": repacker_job_name,
                        "sagemaker_region": f'"{REGION}"',
                    },
                    "InputDataConfig": [
                        {
                            "ChannelName": "training",
                            "DataSource": {
                                "S3DataSource": {
                                    "S3DataDistributionType": "FullyReplicated",
                                    "S3DataType": "S3Prefix",
                                    "S3Uri": f"s3://{BUCKET}",
                                }
                            },
                        }
                    ],
                    "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
                    "ResourceConfig": {
                        "InstanceCount": 1,
                        "InstanceType": "ml.m5.large",
                        "VolumeSizeInGB": 30,
                    },
                    "RoleArn": ROLE,
                    "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
                }
            )
        elif request_dict["Type"] == "RegisterModel":
            assert request_dict["Name"] == "RegisterModelStep"
            arguments = request_dict["Arguments"]
            containers = arguments["InferenceSpecification"]["Containers"]
            assert len(containers) == 1
            assert containers[0]["Image"] == estimator.training_image_uri()
            # The model data URL is checked by type only, not by value.
            assert isinstance(containers[0]["ModelDataUrl"], Properties)
            # Containers were verified above; remove them so the remaining
            # arguments can be compared against a plain literal.
            del arguments["InferenceSpecification"]["Containers"]
            assert ordered(arguments) == ordered(
                {
                    "InferenceSpecification": {
                        "SupportedContentTypes": ["content_type"],
                        "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
                        "SupportedResponseMIMETypes": ["response_type"],
                        "SupportedTransformInstanceTypes": ["transform_instance"],
                    },
                    "ModelApprovalStatus": "Approved",
                    "ModelMetrics": {
                        "ModelQuality": {
                            "Statistics": {
                                "ContentType": "text/csv",
                                "S3Uri": f"s3://{BUCKET}/metrics.csv",
                            },
                        },
                    },
                    "ModelPackageDescription": "description",
                    "ModelPackageGroupName": "mpg",
                }
            )
        else:
            # Raise an assertion failure (not a generic error) and name the
            # unexpected type so the report is actionable.
            raise AssertionError(
                f"A step exists in the collection of an invalid type: {request_dict['Type']}"
            )
| 288 | + |
| 289 | + |
180 | 290 | def test_estimator_transformer(estimator):
|
181 | 291 | model_data = f"s3://{BUCKET}/model.tar.gz"
|
182 | 292 | model_inputs = CreateModelInput(
|
|
0 commit comments