42
42
43
43
44
44
class _RepackModelStep (TrainingStep ):
45
- """Repacks model artifacts with inference entry point .
45
+ """Repacks model artifacts with custom inference entry points .
46
46
47
- Attributes:
48
- name (str): The name of the training step.
49
- step_type (StepTypeEnum): The type of the step with value `StepTypeEnum.Training`.
50
- estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
51
- inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
47
+ The SDK automatically adds this step to pipelines that have RegisterModelSteps with models
48
+ that have a custom entry point.
52
49
"""
53
50
54
51
def __init__ (
@@ -61,19 +58,77 @@ def __init__(
61
58
source_dir : str = None ,
62
59
dependencies : List = None ,
63
60
depends_on : Union [List [str ], List [Step ]] = None ,
61
+ subnets = None ,
62
+ security_group_ids = None ,
64
63
** kwargs ,
65
64
):
66
- """Constructs a TrainingStep, given an `EstimatorBase` instance.
67
-
68
- In addition to the estimator instance, the other arguments are those that are supplied to
69
- the `fit` method of the `sagemaker.estimator.Estimator`.
65
+ """Base class initializer.
70
66
71
67
Args:
72
68
name (str): The name of the training step.
73
- estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
74
- inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
69
+ sagemaker_session (sagemaker.session.Session): Session object which manages
70
+ interactions with Amazon SageMaker APIs and any other AWS services needed. If
71
+ not specified, the estimator creates one using the default
72
+ AWS configuration chain.
73
+ role (str): An AWS IAM role (either name or full ARN). The Amazon
74
+ SageMaker training jobs and APIs that create Amazon SageMaker
75
+ endpoints use this role to access training data and model
76
+ artifacts. After the endpoint is created, the inference code
77
+ might use the IAM role, if it needs to access an AWS resource.
78
+ model_data (str): The S3 location of a SageMaker model data
79
+ ``.tar.gz`` file (default: None).
80
+ entry_point (str): Path (absolute or relative) to the local Python
81
+ source file which should be executed as the entry point to
82
+ inference. If ``source_dir`` is specified, then ``entry_point``
83
+ must point to a file located at the root of ``source_dir``.
84
+ If 'git_config' is provided, 'entry_point' should be
85
+ a relative location to the Python source file in the Git repo.
86
+
87
+ Example:
88
+ With the following GitHub repo directory structure:
89
+
90
+ >>> |----- README.md
91
+ >>> |----- src
92
+ >>> |----- train.py
93
+ >>> |----- test.py
94
+
95
+ You can assign entry_point='src/train.py'.
96
+ source_dir (str): A relative location to a directory with other training
97
+ or model hosting source code dependencies aside from the entry point
98
+ file in the Git repo (default: None). Structure within this
99
+ directory are preserved when training on Amazon SageMaker.
100
+ dependencies (list[str]): A list of paths to directories (absolute
101
+ or relative) with any additional libraries that will be exported
102
+ to the container (default: []). The library folders will be
103
+ copied to SageMaker in the same folder where the entrypoint is
104
+ copied. If 'git_config' is provided, 'dependencies' should be a
105
+ list of relative locations to directories with any additional
106
+ libraries needed in the Git repo.
107
+
108
+ .. admonition:: Example
109
+
110
+ The following call
111
+
112
+ >>> Estimator(entry_point='train.py',
113
+ ... dependencies=['my/libs/common', 'virtual-env'])
114
+
115
+ results in the following inside the container:
116
+
117
+ >>> $ ls
118
+
119
+ >>> opt/ml/code
120
+ >>> |------ train.py
121
+ >>> |------ common
122
+ >>> |------ virtual-env
123
+
124
+ This is not supported with "local code" in Local Mode.
125
+ depends_on (List[str] or List[Step]): A list of step names or instances
126
+ this step depends on
127
+ subnets (list[str]): List of subnet ids. If not specified, the re-packing
128
+ job will be created without VPC config.
129
+ security_group_ids (list[str]): List of security group ids. If not
130
+ specified, the re-packing job will be created without VPC config.
75
131
"""
76
- # yeah, go ahead and save the originals for now
77
132
self ._model_data = model_data
78
133
self .sagemaker_session = sagemaker_session
79
134
self .role = role
@@ -101,6 +156,8 @@ def __init__(
101
156
"inference_script" : self ._entry_point_basename ,
102
157
"model_archive" : self ._model_archive ,
103
158
},
159
+ subnets = subnets ,
160
+ security_group_ids = security_group_ids ,
104
161
** kwargs ,
105
162
)
106
163
repacker .disable_profiler = True
0 commit comments