15
15
16
16
from concurrent .futures import ThreadPoolExecutor
17
17
from dataclasses import dataclass , field
18
- from typing import Any , Union , Dict , List , Tuple
18
+ from typing import Any , Dict , List , Tuple
19
19
20
- from sagemaker .s3 import s3_path_join
21
20
from sagemaker .remote_function .core .serialization import deserialize_obj_from_s3
21
+ from sagemaker .workflow .step_outputs import get_step
22
22
23
23
24
24
@dataclass
@@ -77,22 +77,11 @@ class _ExecutionVariable:
77
77
name : str
78
78
79
79
80
- @dataclass
81
- class _S3BaseUriIdentifier :
82
- """Identifies that the class refers to function step s3 base uri.
83
-
84
- The s3_base_uri = s3_root_uri + pipeline_name.
85
- This identifier is resolved in function step runtime by SDK.
86
- """
87
-
88
- NAME = "S3_BASE_URI"
89
-
90
-
91
80
@dataclass
92
81
class _DelayedReturn :
93
82
"""Delayed return from a function."""
94
83
95
- uri : List [ Union [ str , _Parameter , _ExecutionVariable ]]
84
+ uri : _Properties
96
85
reference_path : Tuple = field (default_factory = tuple )
97
86
98
87
@@ -164,26 +153,18 @@ def __init__(
164
153
self ,
165
154
delayed_returns : List [_DelayedReturn ],
166
155
hmac_key : str ,
167
- parameter_resolver : _ParameterResolver ,
168
- execution_variable_resolver : _ExecutionVariableResolver ,
169
- s3_base_uri : str ,
156
+ properties_resolver : _PropertiesResolver ,
170
157
** settings ,
171
158
):
172
159
"""Resolve delayed return.
173
160
174
161
Args:
175
162
delayed_returns: list of delayed returns to resolve.
176
163
hmac_key: key used to compute the HMAC hash that verifies the integrity of the serialized function and arguments.
177
- parameter_resolver: resolver used to pipeline parameters.
178
- execution_variable_resolver: resolver used to resolve execution variables.
179
- s3_base_uri (str): the s3 base uri of the function step that
180
- the serialized artifacts will be uploaded to.
181
- The s3_base_uri = s3_root_uri + pipeline_name.
164
+ properties_resolver: resolver used to resolve step properties.
182
165
**settings: settings to pass to the deserialization function.
183
166
"""
184
- self ._s3_base_uri = s3_base_uri
185
- self ._parameter_resolver = parameter_resolver
186
- self ._execution_variable_resolver = execution_variable_resolver
167
+ self ._properties_resolver = properties_resolver
187
168
# different delayed returns can have the same uri, so we need to dedupe
188
169
uris = {
189
170
self ._resolve_delayed_return_uri (delayed_return ) for delayed_return in delayed_returns
@@ -214,18 +195,7 @@ def resolve(self, delayed_return: _DelayedReturn) -> Any:
214
195
215
196
def _resolve_delayed_return_uri (self , delayed_return : _DelayedReturn ):
216
197
"""Resolve the s3 uri of the delayed return."""
217
-
218
- uri = []
219
- for component in delayed_return .uri :
220
- if isinstance (component , _Parameter ):
221
- uri .append (self ._parameter_resolver .resolve (component ))
222
- elif isinstance (component , _ExecutionVariable ):
223
- uri .append (self ._execution_variable_resolver .resolve (component ))
224
- elif isinstance (component , _S3BaseUriIdentifier ):
225
- uri .append (self ._s3_base_uri )
226
- else :
227
- uri .append (component )
228
- return s3_path_join (* uri )
198
+ return self ._properties_resolver .resolve (delayed_return .uri )
229
199
230
200
231
201
def _retrieve_child_item (delayed_return : _DelayedReturn , deserialized_obj : Any ):
@@ -241,7 +211,6 @@ def resolve_pipeline_variables(
241
211
func_args : Tuple ,
242
212
func_kwargs : Dict ,
243
213
hmac_key : str ,
244
- s3_base_uri : str ,
245
214
** settings ,
246
215
):
247
216
"""Resolve pipeline variables.
@@ -251,8 +220,6 @@ def resolve_pipeline_variables(
251
220
func_args: function args.
252
221
func_kwargs: function kwargs.
253
222
hmac_key: key used to compute the HMAC hash that verifies the integrity of the serialized function and arguments.
254
- s3_base_uri: the s3 base uri of the function step that the serialized artifacts
255
- will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
256
223
**settings: settings to pass to the deserialization function.
257
224
"""
258
225
@@ -274,9 +241,7 @@ def resolve_pipeline_variables(
274
241
delayed_return_resolver = _DelayedReturnResolver (
275
242
delayed_returns = delayed_returns ,
276
243
hmac_key = hmac_key ,
277
- parameter_resolver = parameter_resolver ,
278
- execution_variable_resolver = execution_variable_resolver ,
279
- s3_base_uri = s3_base_uri ,
244
+ properties_resolver = properties_resolver ,
280
245
** settings ,
281
246
)
282
247
@@ -322,39 +287,27 @@ def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict
322
287
func_args: function args.
323
288
func_kwargs: function kwargs.
324
289
"""
290
+ converted_func_args = tuple (_convert_pipeline_variable_to_pickleable (arg ) for arg in func_args )
291
+ converted_func_kwargs = {
292
+ key : _convert_pipeline_variable_to_pickleable (arg ) for key , arg in func_kwargs .items ()
293
+ }
325
294
326
- from sagemaker .workflow .entities import PipelineVariable
327
-
328
- from sagemaker .workflow .execution_variables import ExecutionVariables
295
+ return converted_func_args , converted_func_kwargs
329
296
330
- from sagemaker .workflow .function_step import DelayedReturn
331
297
332
- # Notes:
333
- # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
334
- # when defining function steps. After step-level arg serialization,
335
- # it's hard to update the s3_base_uri in pipeline compile time.
336
- # Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it.
337
- # 2. For saying s3_root_uri is unknown, it's because when defining function steps,
338
- # the pipeline's sagemaker_session is not passed in, but the default s3_root_uri
339
- # should be retrieved from the pipeline's sagemaker_session.
340
- def convert (arg ):
341
- if isinstance (arg , DelayedReturn ):
342
- return _DelayedReturn (
343
- uri = [
344
- _S3BaseUriIdentifier (),
345
- ExecutionVariables .PIPELINE_EXECUTION_ID ._pickleable ,
346
- arg ._step .name ,
347
- "results" ,
348
- ],
349
- reference_path = arg ._reference_path ,
350
- )
298
+ def _convert_pipeline_variable_to_pickleable (arg ):
299
+ """Convert a pipeline variable to pickleable."""
300
+ from sagemaker .workflow .entities import PipelineVariable
351
301
352
- if isinstance (arg , PipelineVariable ):
353
- return arg ._pickleable
302
+ from sagemaker .workflow .function_step import DelayedReturn
354
303
355
- return arg
304
+ if isinstance (arg , DelayedReturn ):
305
+ return _DelayedReturn (
306
+ uri = get_step (arg )._properties .OutputDataConfig .S3OutputPath ._pickleable ,
307
+ reference_path = arg ._reference_path ,
308
+ )
356
309
357
- converted_func_args = tuple ( convert ( arg ) for arg in func_args )
358
- converted_func_kwargs = { key : convert ( arg ) for key , arg in func_kwargs . items ()}
310
+ if isinstance ( arg , PipelineVariable ):
311
+ return arg . _pickleable
359
312
360
- return converted_func_args , converted_func_kwargs
313
+ return arg
0 commit comments