Skip to content

fix: Propagate tags and VPC configs to repack model steps #2594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 70 additions & 13 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,10 @@


class _RepackModelStep(TrainingStep):
"""Repacks model artifacts with inference entry point.
"""Repacks model artifacts with custom inference entry points.

Attributes:
name (str): The name of the training step.
step_type (StepTypeEnum): The type of the step with value `StepTypeEnum.Training`.
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
The SDK automatically adds this step to pipelines that have RegisterModelSteps with models
that have a custom entry point.
"""

def __init__(
Expand All @@ -61,19 +58,77 @@ def __init__(
source_dir: str = None,
dependencies: List = None,
depends_on: Union[List[str], List[Step]] = None,
subnets=None,
security_group_ids=None,
**kwargs,
):
"""Constructs a TrainingStep, given an `EstimatorBase` instance.

In addition to the estimator instance, the other arguments are those that are supplied to
the `fit` method of the `sagemaker.estimator.Estimator`.
"""Base class initializer.

Args:
name (str): The name of the training step.
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
sagemaker_session (sagemaker.session.Session): Session object which manages
interactions with Amazon SageMaker APIs and any other AWS services needed. If
not specified, the estimator creates one using the default
AWS configuration chain.
role (str): An AWS IAM role (either name or full ARN). The Amazon
SageMaker training jobs and APIs that create Amazon SageMaker
endpoints use this role to access training data and model
artifacts. After the endpoint is created, the inference code
might use the IAM role, if it needs to access an AWS resource.
model_data (str): The S3 location of a SageMaker model data
``.tar.gz`` file (default: None).
entry_point (str): Path (absolute or relative) to the local Python
source file which should be executed as the entry point to
inference. If ``source_dir`` is specified, then ``entry_point``
must point to a file located at the root of ``source_dir``.
If 'git_config' is provided, 'entry_point' should be
a relative location to the Python source file in the Git repo.

Example:
With the following GitHub repo directory structure:

>>> |----- README.md
>>> |----- src
>>> |----- train.py
>>> |----- test.py

You can assign entry_point='src/train.py'.
source_dir (str): A relative location to a directory with other training
or model hosting source code dependencies aside from the entry point
file in the Git repo (default: None). The structure within this
directory is preserved when training on Amazon SageMaker.
dependencies (list[str]): A list of paths to directories (absolute
or relative) with any additional libraries that will be exported
to the container (default: []). The library folders will be
copied to SageMaker in the same folder where the entrypoint is
copied. If 'git_config' is provided, 'dependencies' should be a
list of relative locations to directories with any additional
libraries needed in the Git repo.

.. admonition:: Example

The following call

>>> Estimator(entry_point='train.py',
... dependencies=['my/libs/common', 'virtual-env'])

results in the following inside the container:

>>> $ ls

>>> opt/ml/code
>>> |------ train.py
>>> |------ common
>>> |------ virtual-env

This is not supported with "local code" in Local Mode.
depends_on (List[str] or List[Step]): A list of step names or instances
that this step depends on.
subnets (list[str]): List of subnet ids. If not specified, the re-packing
job will be created without VPC config.
security_group_ids (list[str]): List of security group ids. If not
specified, the re-packing job will be created without VPC config.
"""
# yeah, go ahead and save the originals for now
self._model_data = model_data
self.sagemaker_session = sagemaker_session
self.role = role
Expand Down Expand Up @@ -101,6 +156,8 @@ def __init__(
"inference_script": self._entry_point_basename,
"model_archive": self._model_archive,
},
subnets=subnets,
security_group_ids=security_group_ids,
**kwargs,
)
repacker.disable_profiler = True
Expand Down
21 changes: 20 additions & 1 deletion src/sagemaker/workflow/step_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def __init__(
compile_model_family=None,
description=None,
tags=None,
model=None,
model: Union[Model, PipelineModel] = None,
**kwargs,
):
"""Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
Expand Down Expand Up @@ -109,6 +109,16 @@ def __init__(
repack_model = False
self.model_list = None
self.container_def_list = None
subnets = None
security_group_ids = None

if estimator is not None:
subnets = estimator.subnets
security_group_ids = estimator.security_group_ids
elif model is not None and model.vpc_config is not None:
subnets = model.vpc_config["Subnets"]
security_group_ids = model.vpc_config["SecurityGroupIds"]

if "entry_point" in kwargs:
repack_model = True
entry_point = kwargs.pop("entry_point", None)
Expand All @@ -125,6 +135,9 @@ def __init__(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
tags=tags,
subnets=subnets,
security_group_ids=security_group_ids,
**kwargs,
)
steps.append(repack_model_step)
Expand Down Expand Up @@ -163,6 +176,9 @@ def __init__(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
tags=tags,
subnets=subnets,
security_group_ids=security_group_ids,
**kwargs,
)
steps.append(repack_model_step)
Expand Down Expand Up @@ -283,6 +299,9 @@ def __init__(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
tags=tags,
subnets=estimator.subnets,
security_group_ids=estimator.security_group_ids,
)
steps.append(repack_model_step)
model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts
Expand Down
Loading