Skip to content

Commit 1f95b82

Browse files
committed
add helper function to get tuning step top performing model s3 uri
1 parent 8cf18b8 commit 1f95b82

File tree

3 files changed

+77
-3
lines changed

3 files changed

+77
-3
lines changed

src/sagemaker/workflow/steps.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
PropertyFile,
4545
Properties,
4646
)
47+
from sagemaker.workflow.functions import Join
4748

4849

4950
class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
@@ -525,3 +526,27 @@ def to_request(self) -> RequestType:
525526
request_dict.update(self.cache_config.config)
526527

527528
return request_dict
529+
530+
def get_top_model_s3_uri(self, top_k: int, s3_bucket: str, prefix: str = ""):
531+
"""Get the model artifact s3 uri from the top performing training jobs.
532+
533+
Args:
534+
top_k (int): the index of the top performing training job
535+
tuning step stores up to 50 top performing training jobs, hence
536+
a valid top_k value is from 0 to 49. The best training job
537+
model is at index 0
538+
s3_bucket (str): the s3 bucket to store the training job output artifact
539+
prefix (str): the s3 key prefix to store the training job output artifact
540+
"""
541+
values = ["s3:/", s3_bucket]
542+
if prefix != "" and prefix is not None:
543+
values.append(prefix)
544+
545+
return Join(
546+
on="/",
547+
values=values
548+
+ [
549+
self.properties.TrainingJobSummaries[top_k].TrainingJobName,
550+
"output/model.tar.gz",
551+
],
552+
)

tests/integ/test_workflow.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,8 @@ def test_tuning(
884884
objective_type="Maximize",
885885
hyperparameter_ranges=hyperparameter_ranges,
886886
metric_definitions=[{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
887+
max_jobs=2,
888+
max_parallel_jobs=2,
887889
)
888890

889891
step_tune = TuningStep(
@@ -892,13 +894,48 @@ def test_tuning(
892894
inputs=inputs,
893895
)
894896

897+
best_model = Model(
898+
image_uri=pytorch_estimator.training_image_uri(),
899+
model_data=step_tune.get_top_model_s3_uri(
900+
top_k=0,
901+
s3_bucket=sagemaker_session.default_bucket(),
902+
),
903+
sagemaker_session=sagemaker_session,
904+
role=role,
905+
)
906+
model_inputs = CreateModelInput(
907+
instance_type="ml.m5.large",
908+
accelerator_type="ml.eia1.medium",
909+
)
910+
step_best_model = CreateModelStep(
911+
name="1st-model",
912+
model=best_model,
913+
inputs=model_inputs,
914+
)
915+
916+
second_best_model = Model(
917+
image_uri=pytorch_estimator.training_image_uri(),
918+
model_data=step_tune.get_top_model_s3_uri(
919+
top_k=1,
920+
s3_bucket=sagemaker_session.default_bucket(),
921+
),
922+
sagemaker_session=sagemaker_session,
923+
role=role,
924+
)
925+
926+
step_second_best_model = CreateModelStep(
927+
name="2nd-best-model",
928+
model=second_best_model,
929+
inputs=model_inputs,
930+
)
931+
895932
pipeline = Pipeline(
896933
name=pipeline_name,
897934
parameters=[instance_count, instance_type],
898-
steps=[step_tune],
935+
steps=[step_tune, step_best_model, step_second_best_model],
899936
sagemaker_session=sagemaker_session,
900937
)
901-
print(pipeline.definition())
938+
902939
try:
903940
response = pipeline.create(role)
904941
create_arn = response["PipelineArn"]

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ def test_single_algo_tuning_step(sagemaker_session):
551551
tuner=tuner,
552552
inputs=inputs,
553553
)
554-
print(tuning_step.to_request())
554+
555555
assert tuning_step.to_request() == {
556556
"Name": "MyTuningStep",
557557
"Type": "Tuning",
@@ -644,6 +644,18 @@ def test_single_algo_tuning_step(sagemaker_session):
644644
assert tuning_step.properties.TrainingJobSummaries[0].TrainingJobName.expr == {
645645
"Get": "Steps.MyTuningStep.TrainingJobSummaries[0].TrainingJobName"
646646
}
647+
assert tuning_step.get_top_model_s3_uri(0, "my-bucket", "my-prefix").expr == {
648+
"Std:Join": {
649+
"On": "/",
650+
"Values": [
651+
"s3:/",
652+
"my-bucket",
653+
"my-prefix",
654+
{"Get": "Steps.MyTuningStep.TrainingJobSummaries[0].TrainingJobName"},
655+
"output/model.tar.gz",
656+
],
657+
}
658+
}
647659

648660

649661
def test_multi_algo_tuning_step(sagemaker_session):

0 commit comments

Comments
 (0)