Skip to content

Commit ccf4c1e

Browse files
committed
fix: address requested changes
1 parent f5c0538 commit ccf4c1e

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

src/sagemaker/workflow/steps.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,21 +95,33 @@ def ref(self) -> Dict[str, str]:
9595

9696
@attr.s
9797
class CacheConfig:
98-
"""Step to cache pipeline workflow.
98+
"""Configure steps to enable cache in pipeline workflow.
99+
100+
If caching is enabled, the pipeline attempts to find a previous execution of a step.
101+
If a successful previous execution is found, the pipeline propagates the values
102+
from previous execution rather than recomputing the step.
103+
99104
100105
Attributes:
101-
enable_caching (bool): To enable step caching. Off by default.
106+
enable_caching (bool): To enable step caching. Defaults to `False`.
102107
expire_after (str): If step caching is enabled, a timeout also needs to be defined.
103108
It defines how old a previous execution can be to be considered for reuse.
104-
Needs to be ISO 8601 duration string.
109+
Value should be an ISO 8601 duration string.
110+
If step caching is disabled, it defaults to an empty string.
105111
"""
106112

107113
enable_caching: bool = attr.ib(default=False)
108-
expire_after: str = attr.ib(factory=str)
114+
expire_after = attr.ib(default="")
115+
116+
@expire_after.validator
117+
def validate_expire_after(self, enable_caching, expire_after):
118+
"""Validates ISO 8601 duration string."""
119+
if enable_caching and expire_after == "":
120+
raise ValueError("expire_after must be an ISO 8601 duration string")
109121

110122
@property
111123
def config(self):
112-
"""Enables caching in pipeline steps."""
124+
"""Configures caching in pipeline steps."""
113125
return {"CacheConfig": {"Enabled": self.enable_caching, "ExpireAfter": self.expire_after}}
114126

115127

@@ -132,7 +144,7 @@ def __init__(
132144
name (str): The name of the training step.
133145
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
134146
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
135-
cache_config (CacheConfig): A `sagemaker.steps.CacheConfig` instance to enable caching.
147+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
136148
"""
137149
super(TrainingStep, self).__init__(name, StepTypeEnum.TRAINING)
138150
self.estimator = estimator
@@ -249,7 +261,7 @@ def __init__(
249261
name (str): The name of the transform step.
250262
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
251263
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
252-
cache_config (CacheConfig): An instance to enable caching.
264+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
253265
"""
254266
super(TransformStep, self).__init__(name, StepTypeEnum.TRANSFORM)
255267
self.transformer = transformer
@@ -331,7 +343,7 @@ def __init__(
331343
script to run. Defaults to `None`.
332344
property_files (List[PropertyFile]): A list of property files that workflow looks
333345
for and resolves from the configured processing output list.
334-
cache_config (CacheConfig): An instance to enable caching.
346+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
335347
"""
336348
super(ProcessingStep, self).__init__(name, StepTypeEnum.PROCESSING)
337349
self.processor = processor

tests/integ/test_workflow.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,12 @@
4444
ParameterInteger,
4545
ParameterString,
4646
)
47-
from sagemaker.workflow.steps import CreateModelStep, ProcessingStep, TrainingStep, CacheConfig
47+
from sagemaker.workflow.steps import (
48+
CreateModelStep,
49+
ProcessingStep,
50+
TrainingStep,
51+
CacheConfig,
52+
)
4853
from sagemaker.workflow.step_collections import RegisterModel
4954
from sagemaker.workflow.pipeline import Pipeline
5055
from tests.integ import DATA_DIR

0 commit comments

Comments
 (0)