16
16
import logging
17
17
from typing import Union , List , Dict , Optional
18
18
19
- from sagemaker import Model , PipelineModel
19
+ from sagemaker import Model , PipelineModel , Session
20
20
from sagemaker .workflow ._utils import _RegisterModelStep , _RepackModelStep
21
21
from sagemaker .workflow .pipeline_context import PipelineSession , _ModelStepArguments
22
- from sagemaker .workflow .retry import RetryPolicy
22
+ from sagemaker .workflow .retry import RetryPolicy , SageMakerJobStepRetryPolicy
23
23
from sagemaker .workflow .step_collections import StepCollection
24
24
from sagemaker .workflow .steps import Step , CreateModelStep
25
25
@@ -57,17 +57,72 @@ def __init__(
57
57
If a listed `Step` name does not exist, an error is returned (default: None).
58
58
retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
59
59
policies for the `ModelStep` (default: None).
60
+
61
+ If a list of retry policies is provided, it would be applied to all steps in the
62
+ `ModelStep` collection. Note: in this case, `SageMakerJobStepRetryPolicy`
63
+ is not allowed, since create/register model step does not support it.
64
+ Please find the example below:
65
+
66
+ .. code:: python
67
+
68
+ ModelStep(
69
+ ...
70
+ retry_policies=[
71
+ StepRetryPolicy(...),
72
+ ],
73
+ )
74
+
75
+ If a dict is provided, it can specify different retry policies for different
76
+ types of steps in the `ModelStep` collection. Similarly,
77
+ `SageMakerJobStepRetryPolicy` is not allowed for create/register model step.
78
+ See examples below:
79
+
80
+ .. code:: python
81
+
82
+ ModelStep(
83
+ ...
84
+ retry_policies=dict(
85
+ register_model_retry_policies=[
86
+ StepRetryPolicy(...),
87
+ ],
88
+ repack_model_retry_policies=[
89
+ SageMakerJobStepRetryPolicy(...),
90
+ ],
91
+ )
92
+ )
93
+
94
+ or
95
+
96
+ .. code:: python
97
+
98
+ ModelStep(
99
+ ...
100
+ retry_policies=dict(
101
+ create_model_retry_policies=[
102
+ StepRetryPolicy(...),
103
+ ],
104
+ repack_model_retry_policies=[
105
+ SageMakerJobStepRetryPolicy(...),
106
+ ],
107
+ )
108
+ )
109
+
60
110
display_name (str): The display name of the `ModelStep`.
61
111
The display name provides better UI readability. (default: None).
62
112
description (str): The description of the `ModelStep` (default: None).
63
113
"""
64
114
# TODO: add a doc link in error message once ready
65
- if not isinstance (step_args , _ModelStepArguments ):
66
- raise TypeError (
67
- "To correctly configure a ModelStep, "
68
- "the step_args must be a `_ModelStepArguments` object generated by "
69
- ".create() or .register()."
70
- )
115
+ from sagemaker .workflow .utilities import validate_step_args_input
116
+
117
+ validate_step_args_input (
118
+ step_args = step_args ,
119
+ expected_caller = {
120
+ Session .create_model .__name__ ,
121
+ Session .create_model_package_from_containers .__name__ ,
122
+ },
123
+ error_message = "The step_args of ModelStep must be obtained from model.create() "
124
+ "or model.register()." ,
125
+ )
71
126
if not (step_args .create_model_request is None ) ^ (
72
127
step_args .create_model_package_request is None
73
128
):
@@ -93,7 +148,22 @@ def __init__(
93
148
self ._create_model_args = self .step_args .create_model_request
94
149
self ._register_model_args = self .step_args .create_model_package_request
95
150
self ._need_runtime_repack = self .step_args .need_runtime_repack
151
+ self ._assign_and_validate_retry_policies (retry_policies )
152
+
153
+ if self ._need_runtime_repack :
154
+ self ._append_repack_model_step ()
155
+ if self ._register_model_args :
156
+ self ._append_register_model_step ()
157
+ else :
158
+ self ._append_create_model_step ()
96
159
160
+ def _assign_and_validate_retry_policies (self , retry_policies ):
161
+ """Assign and validate retry policies according to each kind of sub steps
162
+
163
+ Args:
164
+ retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
165
+ policies for the `ModelStep`.
166
+ """
97
167
if isinstance (retry_policies , dict ):
98
168
self ._create_model_retry_policies = retry_policies .get (
99
169
_CREATE_MODEL_RETRY_POLICIES , None
@@ -109,12 +179,23 @@ def __init__(
109
179
self ._register_model_retry_policies = retry_policies
110
180
self ._repack_model_retry_policies = retry_policies
111
181
112
- if self ._need_runtime_repack :
113
- self ._append_repack_model_step ()
114
- if self ._register_model_args :
115
- self ._append_register_model_step ()
116
- else :
117
- self ._append_create_model_step ()
182
+ self ._validate_sagemaker_job_step_retry_policy ()
183
+
184
+ def _validate_sagemaker_job_step_retry_policy (self ):
185
+ """Validate SageMakerJobStepRetryPolicy
186
+
187
+ Validate that SageMakerJobStepRetryPolicy is not assigning to create/register model step.
188
+ """
189
+ retry_policies = set (
190
+ (self ._create_model_retry_policies or []) + (self ._register_model_retry_policies or [])
191
+ )
192
+ for policy in retry_policies :
193
+ if not isinstance (policy , SageMakerJobStepRetryPolicy ):
194
+ continue
195
+ raise ValueError (
196
+ "SageMakerJobStepRetryPolicy is not allowed for a create/register"
197
+ " model step. Please use StepRetryPolicy instead"
198
+ )
118
199
119
200
def _append_register_model_step (self ):
120
201
"""Create and append a `_RegisterModelStep`"""
0 commit comments