Skip to content

Commit f1d9e49

Browse files
author
Payton Staub
committed
Propagate tags and VPC configs to repack model steps
1 parent 0fb0822 commit f1d9e49

File tree

3 files changed

+351
-16
lines changed

3 files changed

+351
-16
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,10 @@
4242

4343

4444
class _RepackModelStep(TrainingStep):
45-
"""Repacks model artifacts with inference entry point.
45+
"""Repacks model artifacts with custom inference entry points.
4646
47-
Attributes:
48-
name (str): The name of the training step.
49-
step_type (StepTypeEnum): The type of the step with value `StepTypeEnum.Training`.
50-
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
51-
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
47+
The SDK automatically adds this step to pipelines that have RegisterModelSteps with models
48+
that have a custom entry point.
5249
"""
5350

5451
def __init__(
@@ -61,19 +58,77 @@ def __init__(
6158
source_dir: str = None,
6259
dependencies: List = None,
6360
depends_on: Union[List[str], List[Step]] = None,
61+
subnets=None,
62+
security_group_ids=None,
6463
**kwargs,
6564
):
66-
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
67-
68-
In addition to the estimator instance, the other arguments are those that are supplied to
69-
the `fit` method of the `sagemaker.estimator.Estimator`.
65+
"""Base class initializer.
7066
7167
Args:
7268
name (str): The name of the training step.
73-
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
74-
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
69+
sagemaker_session (sagemaker.session.Session): Session object which manages
70+
interactions with Amazon SageMaker APIs and any other AWS services needed. If
71+
not specified, the estimator creates one using the default
72+
AWS configuration chain.
73+
role (str): An AWS IAM role (either name or full ARN). The Amazon
74+
SageMaker training jobs and APIs that create Amazon SageMaker
75+
endpoints use this role to access training data and model
76+
artifacts. After the endpoint is created, the inference code
77+
might use the IAM role, if it needs to access an AWS resource.
78+
model_data (str): The S3 location of a SageMaker model data
79+
``.tar.gz`` file (default: None).
80+
entry_point (str): Path (absolute or relative) to the local Python
81+
source file which should be executed as the entry point to
82+
inference. If ``source_dir`` is specified, then ``entry_point``
83+
must point to a file located at the root of ``source_dir``.
84+
If 'git_config' is provided, 'entry_point' should be
85+
a relative location to the Python source file in the Git repo.
86+
87+
Example:
88+
With the following GitHub repo directory structure:
89+
90+
>>> |----- README.md
91+
>>> |----- src
92+
>>> |----- train.py
93+
>>> |----- test.py
94+
95+
You can assign entry_point='src/train.py'.
96+
source_dir (str): A relative location to a directory with other training
97+
or model hosting source code dependencies aside from the entry point
98+
file in the Git repo (default: None). Structure within this
99+
directory are preserved when training on Amazon SageMaker.
100+
dependencies (list[str]): A list of paths to directories (absolute
101+
or relative) with any additional libraries that will be exported
102+
to the container (default: []). The library folders will be
103+
copied to SageMaker in the same folder where the entrypoint is
104+
copied. If 'git_config' is provided, 'dependencies' should be a
105+
list of relative locations to directories with any additional
106+
libraries needed in the Git repo.
107+
108+
.. admonition:: Example
109+
110+
The following call
111+
112+
>>> Estimator(entry_point='train.py',
113+
... dependencies=['my/libs/common', 'virtual-env'])
114+
115+
results in the following inside the container:
116+
117+
>>> $ ls
118+
119+
>>> opt/ml/code
120+
>>> |------ train.py
121+
>>> |------ common
122+
>>> |------ virtual-env
123+
124+
This is not supported with "local code" in Local Mode.
125+
depends_on (List[str] or List[Step]): A list of step names or instances
126+
this step depends on
127+
subnets (list[str]): List of subnet ids. If not specified, the re-packing
128+
job will be created without VPC config.
129+
security_group_ids (list[str]): List of security group ids. If not
130+
specified, the re-packing job will be created without VPC config.
75131
"""
76-
# yeah, go ahead and save the originals for now
77132
self._model_data = model_data
78133
self.sagemaker_session = sagemaker_session
79134
self.role = role
@@ -101,6 +156,8 @@ def __init__(
101156
"inference_script": self._entry_point_basename,
102157
"model_archive": self._model_archive,
103158
},
159+
subnets=subnets,
160+
security_group_ids=security_group_ids,
104161
**kwargs,
105162
)
106163
repacker.disable_profiler = True

src/sagemaker/workflow/step_collections.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __init__(
6969
compile_model_family=None,
7070
description=None,
7171
tags=None,
72-
model=None,
72+
model: Union[Model, PipelineModel] = None,
7373
**kwargs,
7474
):
7575
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
@@ -109,6 +109,16 @@ def __init__(
109109
repack_model = False
110110
self.model_list = None
111111
self.container_def_list = None
112+
subnets = None
113+
security_group_ids = None
114+
115+
if estimator is not None:
116+
subnets = estimator.subnets
117+
security_group_ids = estimator.security_group_ids
118+
elif model is not None and model.vpc_config is not None:
119+
subnets = model.vpc_config["Subnets"]
120+
security_group_ids = model.vpc_config["SecurityGroupIds"]
121+
112122
if "entry_point" in kwargs:
113123
repack_model = True
114124
entry_point = kwargs.pop("entry_point", None)
@@ -125,6 +135,9 @@ def __init__(
125135
entry_point=entry_point,
126136
source_dir=source_dir,
127137
dependencies=dependencies,
138+
tags=tags,
139+
subnets=subnets,
140+
security_group_ids=security_group_ids,
128141
**kwargs,
129142
)
130143
steps.append(repack_model_step)
@@ -163,6 +176,9 @@ def __init__(
163176
entry_point=entry_point,
164177
source_dir=source_dir,
165178
dependencies=dependencies,
179+
tags=tags,
180+
subnets=subnets,
181+
security_group_ids=security_group_ids,
166182
**kwargs,
167183
)
168184
steps.append(repack_model_step)
@@ -283,6 +299,9 @@ def __init__(
283299
entry_point=entry_point,
284300
source_dir=source_dir,
285301
dependencies=dependencies,
302+
tags=tags,
303+
subnets=estimator.subnets,
304+
security_group_ids=estimator.security_group_ids,
286305
)
287306
steps.append(repack_model_step)
288307
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts

0 commit comments

Comments
 (0)