Skip to content

Commit e1d0d45

Browse files
committed
add cache config and unit test
1 parent b28866d commit e1d0d45

File tree

3 files changed

+133
-8
lines changed

3 files changed

+133
-8
lines changed

src/sagemaker/workflow/steps.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class Step(Entity):
6464
Attributes:
6565
name (str): The name of the step.
6666
step_type (StepTypeEnum): The type of the step.
67+
6768
"""
6869

6970
name: str = attr.ib(factory=str)
@@ -93,6 +94,26 @@ def ref(self) -> Dict[str, str]:
9394
return {"Name": self.name}
9495

9596

97+
@attr.s
class CacheConfig:
    """Step-caching configuration for a pipeline workflow step.

    Attributes:
        enable_caching (bool): Whether step caching is enabled. Off by default.
        expire_after (str): If step caching is enabled, a timeout also needs to be
            defined. It defines how old a previous execution can be to be considered
            for reuse. Needs to be an ISO 8601 duration string (e.g. ``"PT1H"``).
            Defaults to an empty string, meaning no timeout has been set.
    """

    enable_caching: bool = attr.ib(default=False)
    expire_after: str = attr.ib(factory=str)

    @property
    def config(self):
        """Build the request fragment describing this cache configuration.

        Returns:
            dict: ``{"CacheConfig": {...}}`` holding ``Enabled`` and, only when a
            timeout has been set, ``ExpireAfter``.
        """
        cache_settings = {"Enabled": self.enable_caching}
        # An empty string is not an ISO 8601 duration, so omit the key instead of
        # sending a value the attribute contract above says is invalid.
        if self.expire_after:
            cache_settings["ExpireAfter"] = self.expire_after
        return {"CacheConfig": cache_settings}
115+
116+
96117
class TrainingStep(Step):
97118
"""Training step for workflow."""
98119

@@ -101,6 +122,7 @@ def __init__(
101122
name: str,
102123
estimator: EstimatorBase,
103124
inputs: TrainingInput = None,
125+
cache_config: CacheConfig = None,
104126
):
105127
"""Construct a TrainingStep, given an `EstimatorBase` instance.
106128
@@ -111,14 +133,15 @@ def __init__(
111133
name (str): The name of the training step.
112134
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
113135
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
136+
cache_config (CacheConfig): An instance to enable caching.
114137
"""
115138
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING)
116139
self.estimator = estimator
117140
self.inputs = inputs
118-
119141
self._properties = Properties(
120142
path=f"Steps.{name}", shape_name="DescribeTrainingJobResponse"
121143
)
144+
self.cache_config = cache_config
122145

123146
@property
124147
def arguments(self) -> RequestType:
@@ -145,6 +168,13 @@ def properties(self):
145168
"""A Properties object representing the DescribeTrainingJobResponse data model."""
146169
return self._properties
147170

171+
def to_request(self) -> RequestType:
    """Get the request structure, adding the cache configuration when one is set.

    Returns:
        RequestType: The parent class's request dictionary, updated with the
        ``CacheConfig`` entry if a ``cache_config`` was supplied.
    """
    request_dict = super().to_request()
    # `cache_config` defaults to None in the constructor; reading `.config` on it
    # unconditionally would raise AttributeError for steps without caching.
    if self.cache_config is not None:
        request_dict.update(self.cache_config.config)

    return request_dict
177+
148178

149179
class CreateModelStep(Step):
150180
"""CreateModel step for workflow."""
@@ -208,6 +238,7 @@ def __init__(
208238
name: str,
209239
transformer: Transformer,
210240
inputs: TransformInput,
241+
cache_config: CacheConfig = None,
211242
):
212243
"""Constructs a TransformStep, given an `Transformer` instance.
213244
@@ -218,11 +249,12 @@ def __init__(
218249
name (str): The name of the transform step.
219250
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
220251
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
252+
cache_config (CacheConfig): An instance to enable caching.
221253
"""
222254
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM)
223255
self.transformer = transformer
224256
self.inputs = inputs
225-
257+
self.cache_config = cache_config
226258
self._properties = Properties(
227259
path=f"Steps.{name}", shape_name="DescribeTransformJobResponse"
228260
)
@@ -258,6 +290,13 @@ def properties(self):
258290
"""A Properties object representing the DescribeTransformJobResponse data model."""
259291
return self._properties
260292

293+
def to_request(self) -> RequestType:
    """Get the request structure, adding the cache configuration when one is set.

    Returns:
        RequestType: The parent class's request dictionary, updated with the
        ``CacheConfig`` entry if a ``cache_config`` was supplied.
    """
    request_dict = super().to_request()
    # `cache_config` defaults to None in the constructor; reading `.config` on it
    # unconditionally would raise AttributeError for steps without caching.
    if self.cache_config is not None:
        request_dict.update(self.cache_config.config)

    return request_dict
299+
261300

262301
class ProcessingStep(Step):
263302
"""Processing step for workflow."""
@@ -271,6 +310,7 @@ def __init__(
271310
job_arguments: List[str] = None,
272311
code: str = None,
273312
property_files: List[PropertyFile] = None,
313+
cache_config: CacheConfig = None,
274314
):
275315
"""Construct a ProcessingStep, given a `Processor` instance.
276316
@@ -290,6 +330,7 @@ def __init__(
290330
script to run. Defaults to `None`.
291331
property_files (List[PropertyFile]): A list of property files that workflow looks
292332
for and resolves from the configured processing output list.
333+
cache_config (CacheConfig): An instance to enable caching.
293334
"""
294335
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING)
295336
self.processor = processor
@@ -306,6 +347,7 @@ def __init__(
306347
self._properties = Properties(
307348
path=f"Steps.{name}", shape_name="DescribeProcessingJobResponse"
308349
)
350+
self.cache_config = cache_config
309351

310352
@property
311353
def arguments(self) -> RequestType:
@@ -336,6 +378,7 @@ def properties(self):
336378
def to_request(self) -> RequestType:
337379
"""Get the request structure for workflow service calls."""
338380
request_dict = super(ProcessingStep, self).to_request()
381+
request_dict.update(self.cache_config.config)
339382
if self.property_files:
340383
request_dict["PropertyFiles"] = [
341384
property_file.expr for property_file in self.property_files

tests/integ/test_workflow.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
CreateModelStep,
4949
ProcessingStep,
5050
TrainingStep,
51+
CacheConfig
5152
)
5253
from sagemaker.workflow.step_collections import RegisterModel
5354
from sagemaker.workflow.pipeline import Pipeline
@@ -551,3 +552,80 @@ def test_training_job_with_debugger(
551552
pipeline.delete()
552553
except Exception:
553554
pass
555+
556+
557+
def test_cache_hit_expired_entry(
558+
sagemaker_session,
559+
workflow_session,
560+
region_name,
561+
role,
562+
script_dir,
563+
pipeline_name,
564+
):
# Integration test: creates, updates, starts and finally deletes a pipeline
# whose only step is a TrainingStep.
# NOTE(review): this function looks like an unfinished draft; see the notes below.
565+
566+
instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
567+
instance_count = ParameterInteger(name="InstanceCount", default_value=1)
568+
569+
# NOTE(review): the right-hand side of this assignment is missing, and the step
# below passes `sklearn_train`, which is not defined in this function - confirm
# which estimator was intended.
570+
estimator =
571+
572+
step_train = TrainingStep(
573+
name="my-train",
574+
estimator=sklearn_train,
575+
inputs=TrainingInput(
# NOTE(review): `step_process` is not defined in this function and no processing
# step is passed to the pipeline below - verify the intended input source.
576+
s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
577+
"train_data"
578+
].S3Output.S3Uri
579+
),
# NOTE(review): the keyword value is missing; presumably a CacheConfig instance
# was intended - confirm.
580+
cache_config=
581+
)
582+
pipeline = Pipeline(
583+
name=pipeline_name,
584+
parameters=[instance_type, instance_count],
585+
steps=[step_train],
586+
sagemaker_session=workflow_session,
587+
)
588+
589+
try:
590+
# NOTE: We should exercise the case when role used in the pipeline execution is
591+
# different than that required of the steps in the pipeline itself. The role in
592+
# the pipeline definition needs to create training and processing jobs and other
593+
# sagemaker entities. However, the jobs created in the steps themselves execute
594+
# under a potentially different role, often requiring access to S3 and other
595+
# artifacts not required during creation of the jobs in the pipeline steps.
596+
response = pipeline.create(role)
597+
create_arn = response["PipelineArn"]
598+
assert re.match(
# NOTE(review): `region` is neither a parameter nor a local of this function (the
# fixture argument is `region_name`) - confirm; the same name is used in the two
# later ARN patterns.
599+
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}",
600+
create_arn,
601+
)
602+
603+
pipeline.parameters = [ParameterInteger(name="InstanceCount", default_value=1)]
604+
response = pipeline.update(role)
605+
update_arn = response["PipelineArn"]
606+
assert re.match(
607+
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}",
608+
update_arn,
609+
)
610+
611+
execution = pipeline.start(parameters={})
612+
assert re.match(
613+
fr"arn:aws:sagemaker:{region}:\d{{12}}:pipeline/{pipeline_name}/execution/",
614+
execution.arn,
615+
)
616+
617+
response = execution.describe()
618+
assert response["PipelineArn"] == create_arn
619+
620+
# A WaiterError from the wait is ignored; the step listing below is checked
# either way.
try:
621+
execution.wait(delay=30, max_attempts=3)
622+
except WaiterError:
623+
pass
624+
execution_steps = execution.list_steps()
625+
assert len(execution_steps) == 1
# NOTE(review): the only step passed to the pipeline is named "my-train", so this
# expected name looks stale - confirm.
626+
assert execution_steps[0]["StepName"] == "sklearn-process"
627+
finally:
# Cleanup: any exception raised by the delete call is ignored.
628+
try:
629+
pipeline.delete()
630+
except Exception:
631+
pass

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
TrainingStep,
3838
TransformStep,
3939
CreateModelStep,
40+
CacheConfig,
4041
)
4142

4243
REGION = "us-west-2"
@@ -114,10 +115,9 @@ def test_training_step(sagemaker_session):
114115
sagemaker_session=sagemaker_session,
115116
)
116117
inputs = TrainingInput(f"s3://{BUCKET}/train_manifest")
118+
cache_config = CacheConfig(enable_caching=False, expire_after="PT1H")
117119
step = TrainingStep(
118-
name="MyTrainingStep",
119-
estimator=estimator,
120-
inputs=inputs,
120+
name="MyTrainingStep", estimator=estimator, inputs=inputs, cache_config=cache_config
121121
)
122122
assert step.to_request() == {
123123
"Name": "MyTrainingStep",
@@ -145,6 +145,7 @@ def test_training_step(sagemaker_session):
145145
"RoleArn": ROLE,
146146
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
147147
},
148+
"CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"},
148149
}
149150
assert step.properties.TrainingJobName.expr == {"Get": "Steps.MyTrainingStep.TrainingJobName"}
150151

@@ -163,11 +164,13 @@ def test_processing_step(sagemaker_session):
163164
destination="processing_manifest",
164165
)
165166
]
167+
cache_config = CacheConfig(enable_caching=False, expire_after="PT1H")
166168
step = ProcessingStep(
167169
name="MyProcessingStep",
168170
processor=processor,
169171
inputs=inputs,
170172
outputs=[],
173+
cache_config=cache_config,
171174
)
172175
assert step.to_request() == {
173176
"Name": "MyProcessingStep",
@@ -197,6 +200,7 @@ def test_processing_step(sagemaker_session):
197200
},
198201
"RoleArn": "DummyRole",
199202
},
203+
"CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"},
200204
}
201205
assert step.properties.ProcessingJobName.expr == {
202206
"Get": "Steps.MyProcessingStep.ProcessingJobName"
@@ -238,10 +242,9 @@ def test_transform_step(sagemaker_session):
238242
sagemaker_session=sagemaker_session,
239243
)
240244
inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
245+
cache_config = CacheConfig(enable_caching=False, expire_after="PT1H")
241246
step = TransformStep(
242-
name="MyTransformStep",
243-
transformer=transformer,
244-
inputs=inputs,
247+
name="MyTransformStep", transformer=transformer, inputs=inputs, cache_config=cache_config
245248
)
246249
assert step.to_request() == {
247250
"Name": "MyTransformStep",
@@ -262,6 +265,7 @@ def test_transform_step(sagemaker_session):
262265
"InstanceType": "c4.4xlarge",
263266
},
264267
},
268+
"CacheConfig": {"Enabled": False, "ExpireAfter": "PT1H"},
265269
}
266270
assert step.properties.TransformJobName.expr == {
267271
"Get": "Steps.MyTransformStep.TransformJobName"

0 commit comments

Comments
 (0)