Skip to content

Commit 97b4d7e

Browse files
committed
support step DependsOn
1 parent 23201d9 commit 97b4d7e

File tree

9 files changed

+214
-17
lines changed

9 files changed

+214
-17
lines changed

src/sagemaker/workflow/_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(
5959
entry_point: str,
6060
source_dir: str = None,
6161
dependencies: List = None,
62+
depends_on: List[str] = None,
6263
):
6364
"""Constructs a TrainingStep, given an `EstimatorBase` instance.
6465
@@ -102,7 +103,9 @@ def __init__(
102103
inputs = TrainingInput(self._model_prefix)
103104

104105
# super!
105-
super(_RepackModelStep, self).__init__(name=name, estimator=repacker, inputs=inputs)
106+
super(_RepackModelStep, self).__init__(
107+
name=name, depends_on=depends_on, estimator=repacker, inputs=inputs
108+
)
106109

107110
def _prepare_for_repacking(self):
108111
"""Prepares the source for the estimator."""
@@ -221,6 +224,7 @@ def __init__(
221224
image_uri=None,
222225
compile_model_family=None,
223226
description=None,
227+
depends_on: List[str] = None,
224228
**kwargs,
225229
):
226230
"""Constructor of a register model step.
@@ -248,9 +252,10 @@ def __init__(
248252
compile_model_family (str): Instance family for compiled model, if specified, a compiled
249253
model will be used (default: None).
250254
description (str): Model Package description (default: None).
255+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep` depends on
251256
**kwargs: additional arguments to `create_model`.
252257
"""
253-
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL)
258+
super(_RegisterModelStep, self).__init__(name, StepTypeEnum.REGISTER_MODEL, depends_on)
254259
self.estimator = estimator
255260
self.model_data = model_data
256261
self.content_types = content_types

src/sagemaker/workflow/condition_step.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ class ConditionStep(Step):
4040
def __init__(
4141
self,
4242
name: str,
43+
depends_on: List[str] = None,
4344
conditions: List[Condition] = None,
4445
if_steps: List[Union[Step, StepCollection]] = None,
4546
else_steps: List[Union[Step, StepCollection]] = None,
@@ -60,7 +61,7 @@ def __init__(
6061
and `sagemaker.workflow.step_collections.StepCollection` instances that are
6162
marked as ready for execution if the list of conditions evaluates to False.
6263
"""
63-
super(ConditionStep, self).__init__(name, StepTypeEnum.CONDITION)
64+
super(ConditionStep, self).__init__(name, StepTypeEnum.CONDITION, depends_on)
6465
self.conditions = conditions or []
6566
self.if_steps = if_steps or []
6667
self.else_steps = else_steps or []

src/sagemaker/workflow/step_collections.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
response_types,
6161
inference_instances,
6262
transform_instances,
63+
depends_on: List[str] = None,
6364
model_package_group_name=None,
6465
model_metrics=None,
6566
approval_status=None,
@@ -80,6 +81,7 @@ def __init__(
8081
generate inferences in real-time (default: None).
8182
transform_instances (list): A list of the instance types on which a transformation
8283
job can be run or on which an endpoint can be deployed (default: None).
84+
depends_on (List[str]): The list of step names the first step in the collection depends on
8385
model_package_group_name (str): The Model Package Group name, exclusive to
8486
`model_package_name`, using `model_package_group_name` makes the Model Package
8587
versioned (default: None).
@@ -94,12 +96,15 @@ def __init__(
9496
**kwargs: additional arguments to `create_model`.
9597
"""
9698
steps: List[Step] = []
99+
repack_model = False
97100
if "entry_point" in kwargs:
101+
repack_model = True
98102
entry_point = kwargs["entry_point"]
99103
source_dir = kwargs.get("source_dir")
100104
dependencies = kwargs.get("dependencies")
101105
repack_model_step = _RepackModelStep(
102106
name=f"{name}RepackModel",
107+
depends_on=depends_on,
103108
estimator=estimator,
104109
model_data=model_data,
105110
entry_point=entry_point,
@@ -130,6 +135,9 @@ def __init__(
130135
description=description,
131136
**kwargs,
132137
)
138+
if not repack_model:
139+
register_model_step.add_depends_on(depends_on)
140+
133141
steps.append(register_model_step)
134142
self.steps = steps
135143

@@ -160,6 +168,7 @@ def __init__(
160168
max_payload=None,
161169
tags=None,
162170
volume_kms_key=None,
171+
depends_on: List[str] = None,
163172
**kwargs,
164173
):
165174
"""Construct steps required for a Transformer step collection:
@@ -196,6 +205,7 @@ def __init__(
196205
it will be the format of the batch transform output.
197206
env (dict): The Environment variables to be set for use during the
198207
transform job (default: None).
208+
depends_on (List[str]): The list of step names the first step in the collection depends on
199209
"""
200210
steps = []
201211
if "entry_point" in kwargs:
@@ -204,6 +214,7 @@ def __init__(
204214
dependencies = kwargs.get("dependencies")
205215
repack_model_step = _RepackModelStep(
206216
name=f"{name}RepackModel",
217+
depends_on=depends_on,
207218
estimator=estimator,
208219
model_data=model_data,
209220
entry_point=entry_point,
@@ -232,6 +243,9 @@ def predict_wrapper(endpoint, session):
232243
model=model,
233244
inputs=model_inputs,
234245
)
246+
if "entry_point" not in kwargs and depends_on:
247+
# if the CreateModelStep is the first step in the collection
248+
model_step.add_depends_on(depends_on)
235249
steps.append(model_step)
236250

237251
transformer = Transformer(

src/sagemaker/workflow/steps.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,12 @@ class Step(Entity):
6363
Attributes:
6464
name (str): The name of the step.
6565
step_type (StepTypeEnum): The type of the step.
66+
depends_on (List[str]): The list of step names the current step depends on
6667
"""
6768

6869
name: str = attr.ib(factory=str)
6970
step_type: StepTypeEnum = attr.ib(factory=StepTypeEnum.factory)
71+
depends_on: List[str] = attr.ib(default=None)
7072

7173
@property
7274
@abc.abstractmethod
@@ -80,11 +82,21 @@ def properties(self):
8082

8183
def to_request(self) -> RequestType:
8284
"""Gets the request structure for workflow service calls."""
83-
return {
85+
request_dict = {
8486
"Name": self.name,
8587
"Type": self.step_type.value,
8688
"Arguments": self.arguments,
8789
}
90+
if self.depends_on:
91+
request_dict["DependsOn"] = self.depends_on
92+
return request_dict
93+
94+
def add_depends_on(self, step_names: List[str]):
95+
if not step_names:
96+
return
97+
if not self.depends_on:
98+
self.depends_on = []
99+
self.depends_on.extend(step_names)
88100

89101
@property
90102
def ref(self) -> Dict[str, str]:
@@ -133,6 +145,7 @@ def __init__(
133145
estimator: EstimatorBase,
134146
inputs: TrainingInput = None,
135147
cache_config: CacheConfig = None,
148+
depends_on: List[str] = None,
136149
):
137150
"""Construct a TrainingStep, given an `EstimatorBase` instance.
138151
@@ -144,8 +157,9 @@ def __init__(
144157
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
145158
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
146159
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
160+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TrainingStep` depends on
147161
"""
148-
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING)
162+
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING, depends_on)
149163
self.estimator = estimator
150164
self.inputs = inputs
151165
self._properties = Properties(
@@ -188,10 +202,7 @@ class CreateModelStep(Step):
188202
"""CreateModel step for workflow."""
189203

190204
def __init__(
191-
self,
192-
name: str,
193-
model: Model,
194-
inputs: CreateModelInput,
205+
self, name: str, model: Model, inputs: CreateModelInput, depends_on: List[str] = None
195206
):
196207
"""Construct a CreateModelStep, given an `sagemaker.model.Model` instance.
197208
@@ -203,8 +214,9 @@ def __init__(
203214
model (Model): A `sagemaker.model.Model` instance.
204215
inputs (CreateModelInput): A `sagemaker.inputs.CreateModelInput` instance.
205216
Defaults to `None`.
217+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.CreateModelStep` depends on
206218
"""
207-
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL)
219+
super(CreateModelStep, self).__init__(name, StepTypeEnum.CREATE_MODEL, depends_on)
208220
self.model = model
209221
self.inputs = inputs or CreateModelInput()
210222

@@ -247,6 +259,7 @@ def __init__(
247259
transformer: Transformer,
248260
inputs: TransformInput,
249261
cache_config: CacheConfig = None,
262+
depends_on: List[str] = None,
250263
):
251264
"""Constructs a TransformStep, given an `Transformer` instance.
252265
@@ -258,8 +271,9 @@ def __init__(
258271
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
259272
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
260273
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
274+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.TransformStep` depends on
261275
"""
262-
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM)
276+
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM, depends_on)
263277
self.transformer = transformer
264278
self.inputs = inputs
265279
self.cache_config = cache_config
@@ -320,6 +334,7 @@ def __init__(
320334
code: str = None,
321335
property_files: List[PropertyFile] = None,
322336
cache_config: CacheConfig = None,
337+
depends_on: List[str] = None,
323338
):
324339
"""Construct a ProcessingStep, given a `Processor` instance.
325340
@@ -340,8 +355,9 @@ def __init__(
340355
property_files (List[PropertyFile]): A list of property files that workflow looks
341356
for and resolves from the configured processing output list.
342357
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
358+
depends_on (List[str]): A list of step names this `sagemaker.workflow.steps.ProcessingStep` depends on
343359
"""
344-
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING)
360+
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING, depends_on)
345361
self.processor = processor
346362
self.inputs = inputs
347363
self.outputs = outputs

tests/integ/test_workflow.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
rule_configs,
2929
)
3030
from datetime import datetime
31-
from sagemaker.inputs import CreateModelInput, TrainingInput
31+
from sagemaker.inputs import CreateModelInput, TrainingInput, TransformInput
3232
from sagemaker.model import Model
3333
from sagemaker.processing import ProcessingInput, ProcessingOutput
3434
from sagemaker.pytorch.estimator import PyTorch
@@ -969,3 +969,110 @@ def test_training_job_with_debugger_and_profiler(
969969
pipeline.delete()
970970
except Exception:
971971
pass
972+
973+
974+
def test_two_processing_job_depends_on(
975+
sagemaker_session,
976+
role,
977+
pipeline_name,
978+
region_name,
979+
cpu_instance_type,
980+
):
981+
instance_count = ParameterInteger(name="InstanceCount", default_value=2)
982+
script_path = os.path.join(DATA_DIR, "dummy_script.py")
983+
984+
pyspark_processor = PySparkProcessor(
985+
base_job_name="sm-spark",
986+
framework_version="2.4",
987+
role=role,
988+
instance_count=instance_count,
989+
instance_type=cpu_instance_type,
990+
max_runtime_in_seconds=1200,
991+
sagemaker_session=sagemaker_session,
992+
)
993+
994+
spark_run_args = pyspark_processor.get_run_args(
995+
submit_app=script_path,
996+
arguments=[
997+
"--s3_input_bucket",
998+
sagemaker_session.default_bucket(),
999+
"--s3_input_key_prefix",
1000+
"spark-input",
1001+
"--s3_output_bucket",
1002+
sagemaker_session.default_bucket(),
1003+
"--s3_output_key_prefix",
1004+
"spark-output",
1005+
],
1006+
)
1007+
1008+
step_pyspark_1 = ProcessingStep(
1009+
name="pyspark-process-1",
1010+
processor=pyspark_processor,
1011+
inputs=spark_run_args.inputs,
1012+
outputs=spark_run_args.outputs,
1013+
job_arguments=spark_run_args.arguments,
1014+
code=spark_run_args.code,
1015+
)
1016+
1017+
step_pyspark_2 = ProcessingStep(
1018+
name="pyspark-process-2",
1019+
depends_on=[step_pyspark_1.name],
1020+
processor=pyspark_processor,
1021+
inputs=spark_run_args.inputs,
1022+
outputs=spark_run_args.outputs,
1023+
job_arguments=spark_run_args.arguments,
1024+
code=spark_run_args.code,
1025+
)
1026+
1027+
pipeline = Pipeline(
1028+
name=pipeline_name,
1029+
parameters=[instance_count],
1030+
steps=[step_pyspark_1, step_pyspark_2],
1031+
sagemaker_session=sagemaker_session,
1032+
)
1033+
1034+
try:
1035+
response = pipeline.create(role)
1036+
create_arn = response["PipelineArn"]
1037+
assert re.match(
1038+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
1039+
create_arn,
1040+
)
1041+
1042+
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
1043+
response = pipeline.update(role)
1044+
update_arn = response["PipelineArn"]
1045+
assert re.match(
1046+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
1047+
update_arn,
1048+
)
1049+
1050+
execution = pipeline.start(parameters={})
1051+
assert re.match(
1052+
fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
1053+
execution.arn,
1054+
)
1055+
1056+
response = execution.describe()
1057+
assert response["PipelineArn"] == create_arn
1058+
1059+
try:
1060+
execution.wait(delay=60)
1061+
except WaiterError:
1062+
pass
1063+
1064+
execution_steps = execution.list_steps()
1065+
assert len(execution_steps) == 2
1066+
time_stamp = {}
1067+
for execution_step in execution_steps:
1068+
name = execution_step["StepName"]
1069+
if name == "pyspark-process-1":
1070+
time_stamp[name] = execution_step["EndTime"]
1071+
else:
1072+
time_stamp[name] = execution_step["StartTime"]
1073+
assert time_stamp["pyspark-process-1"] < time_stamp["pyspark-process-2"]
1074+
finally:
1075+
try:
1076+
pipeline.delete()
1077+
except Exception:
1078+
pass

tests/unit/sagemaker/workflow/test_condition_step.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,17 @@ def test_condition_step():
4343
step1 = CustomStep("MyStep1")
4444
step2 = CustomStep("MyStep2")
4545
cond_step = ConditionStep(
46-
name="MyConditionStep", conditions=[cond], if_steps=[step1], else_steps=[step2]
46+
name="MyConditionStep",
47+
depends_on=["TestStep"],
48+
conditions=[cond],
49+
if_steps=[step1],
50+
else_steps=[step2],
4751
)
52+
cond_step.add_depends_on(["SecondTestStep"])
4853
assert cond_step.to_request() == {
4954
"Name": "MyConditionStep",
5055
"Type": "Condition",
56+
"DependsOn": ["TestStep", "SecondTestStep"],
5157
"Arguments": {
5258
"Conditions": [
5359
{

0 commit comments

Comments
 (0)