Skip to content

Commit 23201d9

Browse files
committed
fix RegisterModel step with model repacking request dict conversion error + add unite and integ tests for it
1 parent cc93061 commit 23201d9

File tree

3 files changed

+215
-0
lines changed

3 files changed

+215
-0
lines changed

src/sagemaker/workflow/step_collections.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ def __init__(
109109
steps.append(repack_model_step)
110110
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
111111

112+
# remove kwargs consumed by model repacking step
113+
kwargs.pop("entry_point", None)
114+
kwargs.pop("source_dir", None)
115+
kwargs.pop("dependencies", None)
116+
112117
register_model_step = _RegisterModelStep(
113118
name=name,
114119
estimator=estimator,

tests/integ/test_workflow.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,106 @@ def test_conditional_pytorch_training_model_registration(
776776
pass
777777

778778

779+
def test_model_registration_with_model_repack(
780+
sagemaker_session,
781+
role,
782+
pipeline_name,
783+
region_name,
784+
):
785+
base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
786+
entry_point = os.path.join(base_dir, "mnist.py")
787+
input_path = sagemaker_session.upload_data(
788+
path=os.path.join(base_dir, "training"),
789+
key_prefix="integ-test-data/pytorch_mnist/training",
790+
)
791+
inputs = TrainingInput(s3_data=input_path)
792+
793+
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
794+
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
795+
good_enough_input = ParameterInteger(name="GoodEnoughInput", default_value=1)
796+
797+
pytorch_estimator = PyTorch(
798+
entry_point=entry_point,
799+
role=role,
800+
framework_version="1.5.0",
801+
py_version="py3",
802+
instance_count=instance_count,
803+
instance_type=instance_type,
804+
sagemaker_session=sagemaker_session,
805+
)
806+
step_train = TrainingStep(
807+
name="pytorch-train",
808+
estimator=pytorch_estimator,
809+
inputs=inputs,
810+
)
811+
812+
step_register = RegisterModel(
813+
name="pytorch-register-model",
814+
estimator=pytorch_estimator,
815+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
816+
content_types=["*"],
817+
response_types=["*"],
818+
inference_instances=["*"],
819+
transform_instances=["*"],
820+
description="test-description",
821+
entry_point=entry_point,
822+
)
823+
824+
model = Model(
825+
image_uri=pytorch_estimator.training_image_uri(),
826+
model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
827+
sagemaker_session=sagemaker_session,
828+
role=role,
829+
)
830+
model_inputs = CreateModelInput(
831+
instance_type="ml.m5.large",
832+
accelerator_type="ml.eia1.medium",
833+
)
834+
step_model = CreateModelStep(
835+
name="pytorch-model",
836+
model=model,
837+
inputs=model_inputs,
838+
)
839+
840+
step_cond = ConditionStep(
841+
name="cond-good-enough",
842+
conditions=[ConditionGreaterThanOrEqualTo(left=good_enough_input, right=1)],
843+
if_steps=[step_train, step_register],
844+
else_steps=[step_model],
845+
)
846+
847+
pipeline = Pipeline(
848+
name=pipeline_name,
849+
parameters=[good_enough_input, instance_count, instance_type],
850+
steps=[step_cond],
851+
sagemaker_session=sagemaker_session,
852+
)
853+
854+
try:
855+
response = pipeline.create(role)
856+
create_arn = response["PipelineArn"]
857+
assert re.match(
858+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
859+
)
860+
861+
execution = pipeline.start(parameters={})
862+
assert re.match(
863+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
864+
execution.arn,
865+
)
866+
867+
execution = pipeline.start(parameters={"GoodEnoughInput": 0})
868+
assert re.match(
869+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
870+
execution.arn,
871+
)
872+
finally:
873+
try:
874+
pipeline.delete()
875+
except Exception:
876+
pass
877+
878+
779879
def test_training_job_with_debugger_and_profiler(
780880
sagemaker_session,
781881
pipeline_name,

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
import pytest
17+
from tests.unit import DATA_DIR
1718

1819
import sagemaker
1920

@@ -38,13 +39,17 @@
3839
StepCollection,
3940
RegisterModel,
4041
)
42+
from sagemaker.workflow.pipeline import Pipeline
4143
from tests.unit.sagemaker.workflow.helpers import ordered
4244

4345
REGION = "us-west-2"
4446
BUCKET = "my-bucket"
4547
IMAGE_URI = "fakeimage"
4648
ROLE = "DummyRole"
4749
MODEL_NAME = "gisele"
50+
MODEL_REPACKING_IMAGE_URI = (
51+
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3"
52+
)
4853

4954

5055
class CustomStep(Step):
@@ -177,6 +182,111 @@ def test_register_model(estimator, model_metrics):
177182
)
178183

179184

185+
def test_register_model_with_model_repack(estimator, model_metrics):
186+
model_data = f"s3://{BUCKET}/model.tar.gz"
187+
register_model = RegisterModel(
188+
name="RegisterModelStep",
189+
estimator=estimator,
190+
model_data=model_data,
191+
content_types=["content_type"],
192+
response_types=["response_type"],
193+
inference_instances=["inference_instance"],
194+
transform_instances=["transform_instance"],
195+
model_package_group_name="mpg",
196+
model_metrics=model_metrics,
197+
approval_status="Approved",
198+
description="description",
199+
entry_point=f"{DATA_DIR}/dummy_script.py",
200+
)
201+
202+
request_dicts = register_model.request_dicts()
203+
assert len(request_dicts) == 2
204+
print(request_dicts)
205+
for request_dict in request_dicts:
206+
if request_dict["Type"] == "Training":
207+
assert request_dict["Name"] == "RegisterModelStepRepackModel"
208+
arguments = request_dict["Arguments"]
209+
repacker_job_name = arguments["HyperParameters"]["sagemaker_job_name"]
210+
assert ordered(arguments) == ordered(
211+
{
212+
"AlgorithmSpecification": {
213+
"TrainingImage": MODEL_REPACKING_IMAGE_URI,
214+
"TrainingInputMode": "File",
215+
},
216+
"DebugHookConfig": {
217+
"CollectionConfigurations": [],
218+
"S3OutputPath": f"s3://{BUCKET}/",
219+
},
220+
"HyperParameters": {
221+
"inference_script": '"dummy_script.py"',
222+
"model_archive": '"model.tar.gz"',
223+
"sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
224+
BUCKET, repacker_job_name.replace('"', "")
225+
),
226+
"sagemaker_program": '"_repack_model.py"',
227+
"sagemaker_container_log_level": "20",
228+
"sagemaker_job_name": repacker_job_name,
229+
"sagemaker_region": f'"{REGION}"',
230+
},
231+
"InputDataConfig": [
232+
{
233+
"ChannelName": "training",
234+
"DataSource": {
235+
"S3DataSource": {
236+
"S3DataDistributionType": "FullyReplicated",
237+
"S3DataType": "S3Prefix",
238+
"S3Uri": f"s3://{BUCKET}",
239+
}
240+
},
241+
}
242+
],
243+
"OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
244+
"ResourceConfig": {
245+
"InstanceCount": 1,
246+
"InstanceType": "ml.m5.large",
247+
"VolumeSizeInGB": 30,
248+
},
249+
"RoleArn": ROLE,
250+
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
251+
}
252+
)
253+
elif request_dict["Type"] == "RegisterModel":
254+
assert request_dict["Name"] == "RegisterModelStep"
255+
arguments = request_dict["Arguments"]
256+
assert len(arguments["InferenceSpecification"]["Containers"]) == 1
257+
assert (
258+
arguments["InferenceSpecification"]["Containers"][0]["Image"]
259+
== estimator.training_image_uri()
260+
)
261+
assert isinstance(
262+
arguments["InferenceSpecification"]["Containers"][0]["ModelDataUrl"], Properties
263+
)
264+
del arguments["InferenceSpecification"]["Containers"]
265+
assert ordered(arguments) == ordered(
266+
{
267+
"InferenceSpecification": {
268+
"SupportedContentTypes": ["content_type"],
269+
"SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
270+
"SupportedResponseMIMETypes": ["response_type"],
271+
"SupportedTransformInstanceTypes": ["transform_instance"],
272+
},
273+
"ModelApprovalStatus": "Approved",
274+
"ModelMetrics": {
275+
"ModelQuality": {
276+
"Statistics": {
277+
"ContentType": "text/csv",
278+
"S3Uri": f"s3://{BUCKET}/metrics.csv",
279+
},
280+
},
281+
},
282+
"ModelPackageDescription": "description",
283+
"ModelPackageGroupName": "mpg",
284+
}
285+
)
286+
else:
287+
raise Exception("A step exists in the collection of an invalid type.")
288+
289+
180290
def test_estimator_transformer(estimator):
181291
model_data = f"s3://{BUCKET}/model.tar.gz"
182292
model_inputs = CreateModelInput(

0 commit comments

Comments
 (0)