|
18 | 18 | import tarfile
|
19 | 19 | import tempfile
|
20 | 20 |
|
21 |
| -from typing import List |
| 21 | +from typing import List, Union |
22 | 22 |
|
23 | 23 | from sagemaker import image_uris
|
24 | 24 | from sagemaker.inputs import TrainingInput
|
@@ -59,7 +59,7 @@ def __init__(
|
59 | 59 | entry_point: str,
|
60 | 60 | source_dir: str = None,
|
61 | 61 | dependencies: List = None,
|
62 |
| - depends_on: List[str] = None, |
| 62 | + depends_on: Union[List[str], List[Step]] = None, |
63 | 63 | **kwargs,
|
64 | 64 | ):
|
65 | 65 | """Constructs a TrainingStep, given an `EstimatorBase` instance.
|
@@ -226,7 +226,7 @@ def __init__(
|
226 | 226 | image_uri=None,
|
227 | 227 | compile_model_family=None,
|
228 | 228 | description=None,
|
229 |
| - depends_on: List[str] = None, |
| 229 | + depends_on: Union[List[str], List[Step]] = None, |
230 | 230 | tags=None,
|
231 | 231 | **kwargs,
|
232 | 232 | ):
|
@@ -255,8 +255,8 @@ def __init__(
|
255 | 255 | compile_model_family (str): Instance family for compiled model, if specified, a compiled
|
256 | 256 | model will be used (default: None).
|
257 | 257 | description (str): Model Package description (default: None).
|
258 |
| - depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep` |
259 |
| - depends on |
| 258 | + depends_on (List[str] or List[Step]): A list of step names or instances |
| 259 | + this step depends on |
260 | 260 | **kwargs: additional arguments to `create_model`.
|
261 | 261 | """
|
262 | 262 | super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
|
|
0 commit comments