Skip to content

Commit a8bf40a

Browse files
qidewenwhenDewen Qi
andauthored
change: Add more validations for pipeline step new interfaces (#3119)
Co-authored-by: Dewen Qi <[email protected]>
1 parent 621e70e commit a8bf40a

File tree

14 files changed

+893
-109
lines changed

14 files changed

+893
-109
lines changed

src/sagemaker/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1340,6 +1340,11 @@ def register(
13401340
Returns:
13411341
str: A string of SageMaker Model Package ARN.
13421342
"""
1343+
if isinstance(self.sagemaker_session, PipelineSession):
1344+
raise TypeError(
1345+
"estimator.register does not support PipelineSession at this moment. "
1346+
"Please use model.register with PipelineSession if you're using the ModelStep."
1347+
)
13431348
default_name = name_from_base(self.base_job_name)
13441349
model_name = model_name or default_name
13451350
if compile_model_family is not None:

src/sagemaker/session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def submit(request):
591591
LOGGER.debug("train request: %s", json.dumps(request, indent=4))
592592
self.sagemaker_client.create_training_job(**request)
593593

594-
self._intercept_create_request(train_request, submit)
594+
self._intercept_create_request(train_request, submit, self.train.__name__)
595595

596596
def _get_train_request( # noqa: C901
597597
self,
@@ -922,7 +922,7 @@ def submit(request):
922922
LOGGER.debug("process request: %s", json.dumps(request, indent=4))
923923
self.sagemaker_client.create_processing_job(**request)
924924

925-
self._intercept_create_request(process_request, submit)
925+
self._intercept_create_request(process_request, submit, self.process.__name__)
926926

927927
def _get_process_request(
928928
self,
@@ -2099,7 +2099,7 @@ def submit(request):
20992099
LOGGER.debug("tune request: %s", json.dumps(request, indent=4))
21002100
self.sagemaker_client.create_hyper_parameter_tuning_job(**request)
21012101

2102-
self._intercept_create_request(tune_request, submit)
2102+
self._intercept_create_request(tune_request, submit, self.create_tuning_job.__name__)
21032103

21042104
def _get_tuning_request(
21052105
self,
@@ -2569,7 +2569,7 @@ def submit(request):
25692569
LOGGER.debug("Transform request: %s", json.dumps(request, indent=4))
25702570
self.sagemaker_client.create_transform_job(**request)
25712571

2572-
self._intercept_create_request(transform_request, submit)
2572+
self._intercept_create_request(transform_request, submit, self.transform.__name__)
25732573

25742574
def _create_model_request(
25752575
self,

src/sagemaker/workflow/model_step.py

Lines changed: 95 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
import logging
1717
from typing import Union, List, Dict, Optional
1818

19-
from sagemaker import Model, PipelineModel
19+
from sagemaker import Model, PipelineModel, Session
2020
from sagemaker.workflow._utils import _RegisterModelStep, _RepackModelStep
2121
from sagemaker.workflow.pipeline_context import PipelineSession, _ModelStepArguments
22-
from sagemaker.workflow.retry import RetryPolicy
22+
from sagemaker.workflow.retry import RetryPolicy, SageMakerJobStepRetryPolicy
2323
from sagemaker.workflow.step_collections import StepCollection
2424
from sagemaker.workflow.steps import Step, CreateModelStep
2525

@@ -57,17 +57,72 @@ def __init__(
5757
If a listed `Step` name does not exist, an error is returned (default: None).
5858
retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
5959
policies for the `ModelStep` (default: None).
60+
61+
If a list of retry policies is provided, it would be applied to all steps in the
62+
`ModelStep` collection. Note: in this case, `SageMakerJobStepRetryPolicy`
63+
is not allowed, since create/register model step does not support it.
64+
Please find the example below:
65+
66+
.. code:: python
67+
68+
ModelStep(
69+
...
70+
retry_policies=[
71+
StepRetryPolicy(...),
72+
],
73+
)
74+
75+
If a dict is provided, it can specify different retry policies for different
76+
types of steps in the `ModelStep` collection. Similarly,
77+
`SageMakerJobStepRetryPolicy` is not allowed for create/register model step.
78+
See examples below:
79+
80+
.. code:: python
81+
82+
ModelStep(
83+
...
84+
retry_policies=dict(
85+
register_model_retry_policies=[
86+
StepRetryPolicy(...),
87+
],
88+
repack_model_retry_policies=[
89+
SageMakerJobStepRetryPolicy(...),
90+
],
91+
)
92+
)
93+
94+
or
95+
96+
.. code:: python
97+
98+
ModelStep(
99+
...
100+
retry_policies=dict(
101+
create_model_retry_policies=[
102+
StepRetryPolicy(...),
103+
],
104+
repack_model_retry_policies=[
105+
SageMakerJobStepRetryPolicy(...),
106+
],
107+
)
108+
)
109+
60110
display_name (str): The display name of the `ModelStep`.
61111
The display name provides better UI readability. (default: None).
62112
description (str): The description of the `ModelStep` (default: None).
63113
"""
64114
# TODO: add a doc link in error message once ready
65-
if not isinstance(step_args, _ModelStepArguments):
66-
raise TypeError(
67-
"To correctly configure a ModelStep, "
68-
"the step_args must be a `_ModelStepArguments` object generated by "
69-
".create() or .register()."
70-
)
115+
from sagemaker.workflow.utilities import validate_step_args_input
116+
117+
validate_step_args_input(
118+
step_args=step_args,
119+
expected_caller={
120+
Session.create_model.__name__,
121+
Session.create_model_package_from_containers.__name__,
122+
},
123+
error_message="The step_args of ModelStep must be obtained from model.create() "
124+
"or model.register().",
125+
)
71126
if not (step_args.create_model_request is None) ^ (
72127
step_args.create_model_package_request is None
73128
):
@@ -93,7 +148,22 @@ def __init__(
93148
self._create_model_args = self.step_args.create_model_request
94149
self._register_model_args = self.step_args.create_model_package_request
95150
self._need_runtime_repack = self.step_args.need_runtime_repack
151+
self._assign_and_validate_retry_policies(retry_policies)
152+
153+
if self._need_runtime_repack:
154+
self._append_repack_model_step()
155+
if self._register_model_args:
156+
self._append_register_model_step()
157+
else:
158+
self._append_create_model_step()
96159

160+
def _assign_and_validate_retry_policies(self, retry_policies):
161+
"""Assign and validate retry policies according to each kind of sub steps
162+
163+
Args:
164+
retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
165+
policies for the `ModelStep`.
166+
"""
97167
if isinstance(retry_policies, dict):
98168
self._create_model_retry_policies = retry_policies.get(
99169
_CREATE_MODEL_RETRY_POLICIES, None
@@ -109,12 +179,23 @@ def __init__(
109179
self._register_model_retry_policies = retry_policies
110180
self._repack_model_retry_policies = retry_policies
111181

112-
if self._need_runtime_repack:
113-
self._append_repack_model_step()
114-
if self._register_model_args:
115-
self._append_register_model_step()
116-
else:
117-
self._append_create_model_step()
182+
self._validate_sagemaker_job_step_retry_policy()
183+
184+
def _validate_sagemaker_job_step_retry_policy(self):
185+
"""Validate SageMakerJobStepRetryPolicy
186+
187+
Validate that SageMakerJobStepRetryPolicy is not assigning to create/register model step.
188+
"""
189+
retry_policies = set(
190+
(self._create_model_retry_policies or []) + (self._register_model_retry_policies or [])
191+
)
192+
for policy in retry_policies:
193+
if not isinstance(policy, SageMakerJobStepRetryPolicy):
194+
continue
195+
raise ValueError(
196+
"SageMakerJobStepRetryPolicy is not allowed for a create/register"
197+
" model step. Please use StepRetryPolicy instead"
198+
)
118199

119200
def _append_register_model_step(self):
120201
"""Create and append a `_RegisterModelStep`"""

src/sagemaker/workflow/pipeline_context.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,45 @@
1515

1616
import warnings
1717
import inspect
18-
from typing import Dict
1918
from functools import wraps
19+
from typing import Dict, Optional
2020

2121
from sagemaker.session import Session, SessionSettings
2222

2323

24-
class _ModelStepArguments:
25-
"""Step arguments entity for ModelStep"""
24+
class _StepArguments:
25+
"""Step arguments entity for `Step`"""
26+
27+
def __init__(self, caller_name: str = None):
28+
"""Create a `_StepArguments`
29+
30+
Args:
31+
caller_name (str): The name of the caller function which is intercepted by the
32+
PipelineSession to get the step arguments.
33+
"""
34+
self.caller_name = caller_name
35+
36+
37+
class _JobStepArguments(_StepArguments):
38+
"""Step arguments entity for job step types
39+
40+
Job step types include: TrainingStep, ProcessingStep, TuningStep, TransformStep
41+
"""
42+
43+
def __init__(self, caller_name: str, args: dict):
44+
"""Create a `_JobStepArguments`
45+
46+
Args:
47+
caller_name (str): The name of the caller function which is intercepted by the
48+
PipelineSession to get the step arguments.
49+
args (dict): The arguments to be used for composing the SageMaker API request.
50+
"""
51+
super(_JobStepArguments, self).__init__(caller_name)
52+
self.args = args
53+
54+
55+
class _ModelStepArguments(_StepArguments):
56+
"""Step arguments entity for `ModelStep`"""
2657

2758
def __init__(self, model):
2859
"""Create a `_ModelStepArguments`
@@ -31,6 +62,7 @@ def __init__(self, model):
3162
model (Model or PipelineModel): A `sagemaker.model.Model`
3263
or `sagemaker.pipeline.PipelineModel` instance
3364
"""
65+
super(_ModelStepArguments, self).__init__()
3466
self.model = model
3567
self.create_model_package_request = None
3668
self.create_model_request = None
@@ -87,9 +119,8 @@ def context(self):
87119
return self._context
88120

89121
@context.setter
90-
def context(self, args: Dict):
91-
# TODO: we should use _StepArguments type to formalize non-ModelStep args
92-
self._context = args
122+
def context(self, value: Optional[_StepArguments] = None):
123+
self._context = value
93124

94125
def _intercept_create_request(self, request: Dict, create, func_name: str = None):
95126
"""This function intercepts the create job request
@@ -99,12 +130,14 @@ def _intercept_create_request(self, request: Dict, create, func_name: str = None
99130
create (functor): a functor calls the sagemaker client create method
100131
func_name (str): the name of the function needed intercepting
101132
"""
102-
if func_name == "create_model":
133+
if func_name == self.create_model.__name__:
103134
self.context.create_model_request = request
104-
elif func_name == "create_model_package_from_containers":
135+
self.context.caller_name = func_name
136+
elif func_name == self.create_model_package_from_containers.__name__:
105137
self.context.create_model_package_request = request
138+
self.context.caller_name = func_name
106139
else:
107-
self.context = request
140+
self.context = _JobStepArguments(func_name, request)
108141

109142
def init_step_arguments(self, model):
110143
"""Create a `_ModelStepArguments` (if not exist) as pipeline context

src/sagemaker/workflow/retry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ def to_request(self) -> RequestType:
148148
request["ExceptionType"] = [e.value for e in self.exception_types]
149149
return request
150150

151+
def __hash__(self):
152+
"""Hash function for StepRetryPolicy types"""
153+
return hash(tuple(self.to_request()))
154+
151155

152156
class SageMakerJobStepRetryPolicy(RetryPolicy):
153157
"""RetryPolicy for exception thrown by SageMaker Job.
@@ -202,3 +206,7 @@ def to_request(self) -> RequestType:
202206
request = super().to_request()
203207
request["ExceptionType"] = [e.value for e in self.exception_type_list]
204208
return request
209+
210+
def __hash__(self):
211+
"""Hash function for SageMakerJobStepRetryPolicy types"""
212+
return hash(tuple(self.to_request()))

0 commit comments

Comments
 (0)