|
19 | 19 |
|
20 | 20 | from sagemaker.estimator import EstimatorBase
|
21 | 21 | from sagemaker.model import Model
|
| 22 | +from sagemaker import PipelineModel |
22 | 23 | from sagemaker.predictor import Predictor
|
23 | 24 | from sagemaker.transformer import Transformer
|
24 | 25 | from sagemaker.workflow.entities import RequestType
|
@@ -68,7 +69,7 @@ def __init__(
|
68 | 69 | compile_model_family=None,
|
69 | 70 | description=None,
|
70 | 71 | tags=None,
|
71 |
| - pipeline_model=None, |
| 72 | + model=None, |
72 | 73 | **kwargs,
|
73 | 74 | ):
|
74 | 75 | """Construct steps `_RepackModelStep` and `_RegisterModelStep` based on the estimator.
|
@@ -100,8 +101,8 @@ def __init__(
|
100 | 101 | that tags will only be applied to newly created model package groups; if the
|
101 | 102 | name of an existing group is passed to "model_package_group_name",
|
102 | 103 | tags will not be applied.
|
103 |
| - pipeline_model (object): A PipelineModel object that comprises a list of models |
104 |
| - which gets executed as a serial inference pipeline. |
| 104 | + model (object or Model): A PipelineModel object that comprises a list of models |
| 105 | + which gets executed as a serial inference pipeline or a Model object. |
105 | 106 | **kwargs: additional arguments to `create_model`.
|
106 | 107 | """
|
107 | 108 | steps: List[Step] = []
|
@@ -134,33 +135,39 @@ def __init__(
|
134 | 135 | kwargs.pop("dependencies", None)
|
135 | 136 | kwargs.pop("output_kms_key", None)
|
136 | 137 |
|
137 |
| - if pipeline_model is not None: |
138 |
| - self.model_list = pipeline_model.models |
139 |
| - for model in pipeline_model.models: |
| 138 | + if model is not None: |
| 139 | + if isinstance(model, PipelineModel): |
| 140 | + self.model_list = model.models |
| 141 | + elif isinstance(model, Model): |
| 142 | + self.model_list = [model] |
| 143 | + |
| 144 | + for model_entity in self.model_list: |
140 | 145 | if estimator is not None:
|
141 | 146 | sagemaker_session = estimator.sagemaker_session
|
142 | 147 | role = estimator.role
|
143 | 148 | else:
|
144 |
| - sagemaker_session = pipeline_model.sagemaker_session or model.sagemaker_session |
145 |
| - role = pipeline_model.role |
146 |
| - if hasattr(model, "entry_point"): |
| 149 | + sagemaker_session = model_entity.sagemaker_session |
| 150 | + role = model_entity.role |
| 151 | + if hasattr(model_entity, "entry_point"): |
147 | 152 | repack_model = True
|
148 |
| - entry_point = model.entry_point |
149 |
| - source_dir = model.source_dir |
150 |
| - dependencies = model.dependencies |
151 |
| - name = model.name or model._framework_name |
| 153 | + entry_point = model_entity.entry_point |
| 154 | + source_dir = model_entity.source_dir |
| 155 | + dependencies = model_entity.dependencies |
| 156 | + name = model_entity.name or model_entity._framework_name |
152 | 157 | repack_model_step = _RepackModelStep(
|
153 | 158 | name=f"{name}RepackModel",
|
154 | 159 | depends_on=depends_on,
|
155 | 160 | sagemaker_session=sagemaker_session,
|
156 | 161 | role=role,
|
157 |
| - model_data=model.model_data, |
| 162 | + model_data=model_entity.model_data, |
158 | 163 | entry_point=entry_point,
|
159 | 164 | source_dir=source_dir,
|
160 | 165 | dependencies=dependencies,
|
161 | 166 | )
|
162 | 167 | steps.append(repack_model_step)
|
163 |
| - model.model_data = repack_model_step.properties.ModelArtifacts.S3ModelArtifacts |
| 168 | + model_entity.model_data = ( |
| 169 | + repack_model_step.properties.ModelArtifacts.S3ModelArtifacts |
| 170 | + ) |
164 | 171 |
|
165 | 172 | register_model_step = _RegisterModelStep(
|
166 | 173 | name=name,
|
|
0 commit comments