@@ -64,6 +64,7 @@ class Step(Entity):
64
64
Attributes:
65
65
name (str): The name of the step.
66
66
step_type (StepTypeEnum): The type of the step.
67
+
67
68
"""
68
69
69
70
name : str = attr .ib (factory = str )
@@ -93,6 +94,26 @@ def ref(self) -> Dict[str, str]:
93
94
return {"Name" : self .name }
94
95
95
96
97
+ @attr .s
98
+ class CacheConfig :
99
+ """Step to cache pipeline workflow.
100
+
101
+ Attributes:
102
+ enable_caching (bool): To enable step caching. Off by default.
103
+ expire_after (str): If step caching is enabled, a timeout also needs to defined.
104
+ It defines how old a previous execution can be to be considered for reuse.
105
+ Needs to be ISO 8601 duration string.
106
+ """
107
+
108
+ enable_caching : bool = attr .ib (default = False )
109
+ expire_after : str = attr .ib (factory = str )
110
+
111
+ @property
112
+ def config (self ):
113
+ """Enables caching in pipeline steps."""
114
+ return {"CacheConfig" : {"Enabled" : self .enable_caching , "ExpireAfter" : self .expire_after }}
115
+
116
+
96
117
class TrainingStep (Step ):
97
118
"""Training step for workflow."""
98
119
@@ -101,6 +122,7 @@ def __init__(
101
122
name : str ,
102
123
estimator : EstimatorBase ,
103
124
inputs : TrainingInput = None ,
125
+ cache_config : CacheConfig = None ,
104
126
):
105
127
"""Construct a TrainingStep, given an `EstimatorBase` instance.
106
128
@@ -111,14 +133,15 @@ def __init__(
111
133
name (str): The name of the training step.
112
134
estimator (EstimatorBase): A `sagemaker.estimator.EstimatorBase` instance.
113
135
inputs (TrainingInput): A `sagemaker.inputs.TrainingInput` instance. Defaults to `None`.
136
+ cache_config (CacheConfig): An instance to enable caching.
114
137
"""
115
138
super (TrainingStep , self ).__init__ (name , StepTypeEnum .TRAINING )
116
139
self .estimator = estimator
117
140
self .inputs = inputs
118
-
119
141
self ._properties = Properties (
120
142
path = f"Steps.{ name } " , shape_name = "DescribeTrainingJobResponse"
121
143
)
144
+ self .cache_config = cache_config
122
145
123
146
@property
124
147
def arguments (self ) -> RequestType :
@@ -145,6 +168,13 @@ def properties(self):
145
168
"""A Properties object representing the DescribeTrainingJobResponse data model."""
146
169
return self ._properties
147
170
171
+ def to_request (self ) -> RequestType :
172
+ """Updates the dictionary with cache configuration."""
173
+ request_dict = super ().to_request ()
174
+ request_dict .update (self .cache_config .config )
175
+
176
+ return request_dict
177
+
148
178
149
179
class CreateModelStep (Step ):
150
180
"""CreateModel step for workflow."""
@@ -208,6 +238,7 @@ def __init__(
208
238
name : str ,
209
239
transformer : Transformer ,
210
240
inputs : TransformInput ,
241
+ cache_config : CacheConfig = None ,
211
242
):
212
243
"""Constructs a TransformStep, given an `Transformer` instance.
213
244
@@ -218,11 +249,12 @@ def __init__(
218
249
name (str): The name of the transform step.
219
250
transformer (Transformer): A `sagemaker.transformer.Transformer` instance.
220
251
inputs (TransformInput): A `sagemaker.inputs.TransformInput` instance.
252
+ cache_config (CacheConfig): An instance to enable caching.
221
253
"""
222
254
super (TransformStep , self ).__init__ (name , StepTypeEnum .TRANSFORM )
223
255
self .transformer = transformer
224
256
self .inputs = inputs
225
-
257
+ self . cache_config = cache_config
226
258
self ._properties = Properties (
227
259
path = f"Steps.{ name } " , shape_name = "DescribeTransformJobResponse"
228
260
)
@@ -258,6 +290,13 @@ def properties(self):
258
290
"""A Properties object representing the DescribeTransformJobResponse data model."""
259
291
return self ._properties
260
292
293
+ def to_request (self ) -> RequestType :
294
+ """Updates the dictionary with cache configuration."""
295
+ request_dict = super ().to_request ()
296
+ request_dict .update (self .cache_config .config )
297
+
298
+ return request_dict
299
+
261
300
262
301
class ProcessingStep (Step ):
263
302
"""Processing step for workflow."""
@@ -271,6 +310,7 @@ def __init__(
271
310
job_arguments : List [str ] = None ,
272
311
code : str = None ,
273
312
property_files : List [PropertyFile ] = None ,
313
+ cache_config : CacheConfig = None ,
274
314
):
275
315
"""Construct a ProcessingStep, given a `Processor` instance.
276
316
@@ -290,6 +330,7 @@ def __init__(
290
330
script to run. Defaults to `None`.
291
331
property_files (List[PropertyFile]): A list of property files that workflow looks
292
332
for and resolves from the configured processing output list.
333
+ cache_config (CacheConfig): An instance to enable caching.
293
334
"""
294
335
super (ProcessingStep , self ).__init__ (name , StepTypeEnum .PROCESSING )
295
336
self .processor = processor
@@ -306,6 +347,7 @@ def __init__(
306
347
self ._properties = Properties (
307
348
path = f"Steps.{ name } " , shape_name = "DescribeProcessingJobResponse"
308
349
)
350
+ self .cache_config = cache_config
309
351
310
352
@property
311
353
def arguments (self ) -> RequestType :
@@ -336,6 +378,7 @@ def properties(self):
336
378
def to_request (self ) -> RequestType :
337
379
"""Get the request structure for workflow service calls."""
338
380
request_dict = super (ProcessingStep , self ).to_request ()
381
+ request_dict .update (self .cache_config .config )
339
382
if self .property_files :
340
383
request_dict ["PropertyFiles" ] = [
341
384
property_file .expr for property_file in self .property_files
0 commit comments