Skip to content

Commit 742abd3

Browse files
author
Dewen Qi
committed
change: Clean code and add retry to integ tests
1 parent ca998c4 commit 742abd3

File tree

6 files changed

+197
-104
lines changed

6 files changed

+197
-104
lines changed

src/sagemaker/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def create(
405405
https://boto3.amazonaws.com/v1/documentation
406406
/api/latest/reference/services/sagemaker.html#SageMaker.Client.add_tags
407407
"""
408+
# TODO: we should replace _create_sagemaker_model() with create()
408409
self._create_sagemaker_model(
409410
instance_type=instance_type,
410411
accelerator_type=accelerator_type,

src/sagemaker/workflow/_utils.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,7 @@ def __init__(
7676
endpoints use this role to access training data and model
7777
artifacts. After the endpoint is created, the inference code
7878
might use the IAM role, if it needs to access an AWS resource.
79-
model_data (str): The S3 location of a SageMaker model data
80-
``.tar.gz`` file.
79+
model_data (str): The S3 location of a SageMaker model data `.tar.gz` file.
8180
entry_point (str): Path (absolute or relative) to the local Python
8281
source file which should be executed as the entry point to
8382
inference. If ``source_dir`` is specified, then ``entry_point``
@@ -94,6 +93,8 @@ def __init__(
9493
>>> |----- test.py
9594
9695
You can assign entry_point='src/train.py'.
96+
display_name (str): The display name of this `_RepackModelStep` step (default: None).
97+
description (str): The description of this `_RepackModelStep` (default: None).
9798
source_dir (str): A relative location to a directory with other training
9899
or model hosting source code dependencies aside from the entry point
99100
file in the Git repo (default: None). Structure within this
@@ -124,12 +125,14 @@ def __init__(
124125
125126
This is not supported with "local code" in Local Mode.
126127
depends_on (List[str] or List[Step]): A list of step names or instances
127-
this step depends on
128+
this step depends on (default: None).
128129
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
130+
(default: None).
129131
subnets (list[str]): List of subnet ids. If not specified, the re-packing
130-
job will be created without VPC config.
132+
job will be created without VPC config (default: None).
131133
security_group_ids (list[str]): List of security group ids. If not
132-
specified, the re-packing job will be created without VPC config.
134+
specified, the re-packing job will be created without VPC config (default: None).
135+
**kwargs: additional arguments for the repacking job.
133136
"""
134137
self._model_data = model_data
135138
self.sagemaker_session = sagemaker_session
@@ -283,36 +286,38 @@ def __init__(
283286
284287
Args:
285288
name (str): The name of the training step.
286-
step_args (dict): The arguments for the `_RegisterModelStep` definition (default: None).
287-
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance
289+
step_args (dict): The arguments for this `_RegisterModelStep` definition
288290
(default: None).
289-
model_data: the S3 URI to the model data from training (default: None).
290-
content_types (list): The supported MIME types for the
291-
input data (default: None).
292-
response_types (list): The supported MIME types for
293-
the output data (default: None).
291+
content_types (list): The supported MIME types for the input data (default: None).
292+
response_types (list): The supported MIME types for the output data (default: None).
294293
inference_instances (list): A list of the instance types that are used to
295294
generate inferences in real-time (default: None).
296295
transform_instances (list): A list of the instance types on which a
297296
transformation job can be run or on which an endpoint
298297
can be deployed (default: None).
298+
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance
299+
(default: None).
300+
model_data: the S3 URI to the model data from training (default: None).
299301
model_package_group_name (str): Model Package Group name, exclusive to
300302
`model_package_name`, using `model_package_group_name`
301303
makes the Model Package versioned (default: None).
302304
model_metrics (ModelMetrics): ModelMetrics object (default: None).
303-
metadata_properties (MetadataProperties): MetadataProperties object
304-
(default: None).
305+
metadata_properties (MetadataProperties): MetadataProperties object (default: None).
305306
approval_status (str): Model Approval Status, values can be "Approved",
306-
"Rejected", or "PendingManualApproval"
307-
(default: "PendingManualApproval").
307+
"Rejected", or "PendingManualApproval" (default: "PendingManualApproval").
308308
image_uri (str): The container image uri for Model Package, if not specified,
309309
Estimator's training container image will be used (default: None).
310310
compile_model_family (str): Instance family for compiled model,
311311
if specified, a compiled model will be used (default: None).
312+
display_name (str): The display name of this `_RegisterModelStep` step (default: None).
312313
description (str): Model Package description (default: None).
313314
depends_on (List[str] or List[Step]): A list of step names or instances
314315
this step depends on (default: None).
315316
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
317+
(default: None).
318+
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs used to
319+
configure the create model package request (default: None).
320+
container_def_list (list): A list of container definitions (default: None).
316321
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
317322
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
318323
metadata properties (default: None).

src/sagemaker/workflow/model_step.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
that this `ModelStep` depends on.
5656
If a listed `Step` name does not exist, an error is returned (default: None).
5757
retry_policies (List[RetryPolicy] or Dict[str, List[RetryPolicy]]): The list of retry
58-
policies for the `ModelStep`.
58+
policies for the `ModelStep` (default: None).
5959
display_name (str): The display name of the `ModelStep`.
6060
The display name provides better UI readability. (default: None).
6161
description (str): The description of the `ModelStep` (default: None).
@@ -107,8 +107,9 @@ def __init__(
107107
self._create_model_retry_policies = retry_policies
108108
self._register_model_retry_policies = retry_policies
109109
self._repack_model_retry_policies = retry_policies
110-
self._repack()
111110

111+
if self._need_runtime_repack:
112+
self._append_repack_model_step()
112113
if self._register_model_args:
113114
self._append_register_model_step()
114115
else:
@@ -140,12 +141,8 @@ def _append_create_model_step(self):
140141
create_model_step.add_depends_on(self.depends_on)
141142
self.steps.append(create_model_step)
142143

143-
def _repack(self):
144-
"""Add `_RepackModelStep` for the runtime repack"""
145-
146-
if not self._need_runtime_repack:
147-
return
148-
144+
def _append_repack_model_step(self):
145+
"""Create and append a `_RepackModelStep` for the runtime repack"""
149146
if isinstance(self._model, PipelineModel):
150147
model_list = self._model.models
151148
elif isinstance(self._model, Model):

src/sagemaker/workflow/pipeline_context.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,8 @@ def runnable_by_pipeline(run_func):
133133
"""
134134

135135
def wrapper(*args, **kwargs):
136-
if isinstance(args[0].sagemaker_session, PipelineSession):
136+
self_instance = args[0]
137+
if isinstance(self_instance.sagemaker_session, PipelineSession):
137138
run_func_params = inspect.signature(run_func).parameters
138139
arg_list = list(args)
139140

@@ -161,14 +162,14 @@ def wrapper(*args, **kwargs):
161162
UserWarning,
162163
)
163164
if run_func.__name__ in ["register", "create"]:
164-
args[0].sagemaker_session.init_step_arguments(args[0])
165+
self_instance.sagemaker_session.init_step_arguments(self_instance)
165166
run_func(*args, **kwargs)
166-
context = args[0].sagemaker_session.context
167-
args[0].sagemaker_session.context = None
167+
context = self_instance.sagemaker_session.context
168+
self_instance.sagemaker_session.context = None
168169
return context
169170

170171
run_func(*args, **kwargs)
171-
return args[0].sagemaker_session.context
172+
return self_instance.sagemaker_session.context
172173

173174
return run_func(*args, **kwargs)
174175

src/sagemaker/workflow/step_collections.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""The step definitions for workflow."""
1414
from __future__ import absolute_import
1515

16+
import warnings
1617
from typing import List, Union
1718

1819
import attr
@@ -243,6 +244,16 @@ def __init__(
243244
steps.append(register_model_step)
244245
self.steps = steps
245246

247+
# TODO: add public document link here once ready
248+
warnings.warn(
249+
(
250+
"We are deprecating the use of RegisterModel. "
251+
"Instead, please use the ModelStep, which simply takes in the step arguments "
252+
"generated by model.register()."
253+
),
254+
DeprecationWarning,
255+
)
256+
246257

247258
RegisterModel = deprecated_class(RegisterModel, "RegisterModel")
248259

0 commit comments

Comments
 (0)