@@ -17,7 +17,7 @@
 from collections import deque
 import time
 import threading
-from typing import Callable, Dict, List, Optional, Tuple, Any
+from typing import Callable, Dict, List, Optional, Tuple, Any, Union
 import functools
 import itertools
 import inspect
@@ -39,7 +39,7 @@
 from sagemaker.remote_function import logging_config
 from sagemaker.utils import name_from_base, base_from_name
 from sagemaker.remote_function.spark_config import SparkConfig
-from sagemaker.remote_function.workdir_config import WorkdirConfig
+from sagemaker.remote_function.custom_file_filter import CustomFileFilter

 _API_CALL_LIMIT = {
     "SubmittingIntervalInSecs": 1,
@@ -66,7 +66,7 @@ def remote(
     environment_variables: Dict[str, str] = None,
     image_uri: str = None,
     include_local_workdir: bool = False,
-    workdir_config: WorkdirConfig = None,
+    custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
     instance_count: int = 1,
     instance_type: str = None,
     job_conda_env: str = None,
@@ -87,7 +87,6 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
-    custom_file_filter: Optional[Callable[[str, List], List]] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -195,10 +194,12 @@ def remote(
             methods that are not available via PyPI or conda. Only python files are included.
             Default value is ``False``.

-        workdir_config (WorkdirConfig): A ``WorkdirConfig`` object that specifies the
-            local directories and files to be included in the remote function.
-            workdir_config takes precedence over include_local_workdir.
-            Default value is ``None``.
+        custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
+            that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
+            that specifies the local directories and files to be included in the remote function.
+            If a callable is passed in, that function is passed to the ``ignore`` argument of
+            ``shutil.copytree``. Defaults to ``None``, which means only python
+            files are accepted and uploaded to S3.

         instance_count (int): The number of instances to use. Defaults to 1.
             NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
@@ -274,11 +275,6 @@ def remote(
         max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
             After this amount of time Amazon SageMaker will stop waiting for managed spot training
             job to complete. Defaults to ``None``.
-
-        custom_file_filter (Callable[[str, List], List]): A function that filters job
-            dependencies to be uploaded to S3. This function is passed to the ``ignore``
-            argument of ``shutil.copytree``. Defaults to ``None``, which means only python
-            files are accepted.
     """

     def _remote(func):
@@ -290,7 +286,7 @@ def _remote(func):
             environment_variables=environment_variables,
             image_uri=image_uri,
             include_local_workdir=include_local_workdir,
-            workdir_config=workdir_config,
+            custom_file_filter=custom_file_filter,
             instance_count=instance_count,
             instance_type=instance_type,
             job_conda_env=job_conda_env,
@@ -311,7 +307,6 @@ def _remote(func):
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
-            custom_file_filter=custom_file_filter,
         )

         @functools.wraps(func)
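
A minimal usage sketch of the new decorator argument, for context. The callable form follows the ``shutil.copytree`` ``ignore`` contract described in the docstring above; the filter function names and the ``CustomFileFilter`` constructor keyword are illustrative assumptions, not taken from this diff.

from sagemaker.remote_function import remote
from sagemaker.remote_function.custom_file_filter import CustomFileFilter

def skip_notebooks(directory, names):
    # shutil.copytree-style ignore callable: given a directory path and the
    # entries in it, return the subset of names that should NOT be uploaded.
    return [name for name in names if name.endswith(".ipynb")]

@remote(include_local_workdir=True, custom_file_filter=skip_notebooks)
def train(x):
    return x * 2

# Or pass a CustomFileFilter object instead of a callable. The constructor
# keyword below is assumed for illustration; see custom_file_filter.py for
# the actual signature.
@remote(include_local_workdir=True,
        custom_file_filter=CustomFileFilter(ignore_name_patterns=["*.ipynb", "data"]))
def evaluate(x):
    return x + 1
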
@@ -501,7 +496,7 @@ def __init__(
         environment_variables: Dict[str, str] = None,
         image_uri: str = None,
         include_local_workdir: bool = False,
-        workdir_config: WorkdirConfig = None,
+        custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
         instance_count: int = 1,
         instance_type: str = None,
         job_conda_env: str = None,
@@ -523,7 +518,6 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
-        custom_file_filter: Optional[Callable[[str, List], List]] = None,
     ):
         """Constructor for RemoteExecutor

@@ -628,10 +622,12 @@ def __init__(
                 local directories. Set to ``True`` if the remote function code imports local modules
                 and methods that are not available via PyPI or conda. Default value is ``False``.

-            workdir_config (WorkdirConfig): A ``WorkdirConfig`` object that specifies the
-                local directories and files to be included in the remote function.
-                workdir_config takes precedence over include_local_workdir.
-                Default value is ``None``.
+            custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
+                that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
+                that specifies the local directories and files to be included in the remote function.
+                If a callable is passed in, that function is passed to the ``ignore`` argument of
+                ``shutil.copytree``. Defaults to ``None``, which means only python
+                files are accepted and uploaded to S3.

             instance_count (int): The number of instances to use. Defaults to 1.
                 NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
@@ -715,11 +711,6 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot training
                 job to complete. Defaults to ``None``.
-
-            custom_file_filter (Callable[[str, List], List]): A function that filters job
-                dependencies to be uploaded to S3. This function is passed to the ``ignore``
-                argument of ``shutil.copytree``. Defaults to ``None``, which means only python
-                files are accepted.
         """
         self.max_parallel_jobs = max_parallel_jobs

@@ -739,7 +730,7 @@ def __init__(
             environment_variables=environment_variables,
             image_uri=image_uri,
             include_local_workdir=include_local_workdir,
-            workdir_config=workdir_config,
+            custom_file_filter=custom_file_filter,
             instance_count=instance_count,
             instance_type=instance_type,
             job_conda_env=job_conda_env,
@@ -760,7 +751,6 @@ def __init__(
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
-            custom_file_filter=custom_file_filter,
         )

         self._state_condition = threading.Condition()
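
The same argument is now accepted by the ``RemoteExecutor`` constructor. A brief sketch of the executor path, assuming ``RemoteExecutor`` is importable from ``sagemaker.remote_function``; the filter function and the specific instance settings are illustrative.

import os

from sagemaker.remote_function import RemoteExecutor

def only_python_sources(directory, names):
    # Ignore everything that is neither a .py file nor a subdirectory,
    # mirroring the documented default of uploading only python files.
    return [
        name for name in names
        if not name.endswith(".py") and not os.path.isdir(os.path.join(directory, name))
    ]

with RemoteExecutor(
    max_parallel_jobs=2,
    instance_type="ml.m5.xlarge",
    include_local_workdir=True,
    custom_file_filter=only_python_sources,
) as executor:
    future = executor.submit(pow, 2, 10)
    print(future.result())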