Skip to content

Commit 447e531

Browse files
Namrata Madan authored and benieric committed
fix: merge WorkdirConfig and custom_file_filter parameters
1 parent 3ee0b13 commit 447e531

File tree

14 files changed

+126
-145
lines changed

14 files changed

+126
-145
lines changed

src/sagemaker/config/config_schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
ENVIRONMENT_VARIABLES = "EnvironmentVariables"
5353
IMAGE_URI = "ImageUri"
5454
INCLUDE_LOCAL_WORKDIR = "IncludeLocalWorkDir"
55+
CUSTOM_FILE_FILTER = "CustomFileFilter"
5556
INSTANCE_TYPE = "InstanceType"
5657
S3_KMS_KEY_ID = "S3KmsKeyId"
5758
S3_ROOT_URI = "S3RootUri"
@@ -733,7 +734,7 @@ def _simple_path(*args: str):
733734
},
734735
IMAGE_URI: {TYPE: "string"},
735736
INCLUDE_LOCAL_WORKDIR: {TYPE: "boolean"},
736-
"WorkdirConfig": {
737+
CUSTOM_FILE_FILTER: {
737738
TYPE: OBJECT,
738739
ADDITIONAL_PROPERTIES: False,
739740
PROPERTIES: {

src/sagemaker/feature_store/feature_processor/_config_uploader.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Contains classes for preparing and uploading configs for a scheduled feature processor."""
1414
from __future__ import absolute_import
15-
from typing import Callable, Dict, Optional, Tuple, List
15+
from typing import Callable, Dict, Optional, Tuple, List, Union
1616

1717
import attr
1818

@@ -38,7 +38,7 @@
3838
RuntimeEnvironmentManager,
3939
)
4040
from sagemaker.remote_function.spark_config import SparkConfig
41-
from sagemaker.remote_function.workdir_config import WorkdirConfig
41+
from sagemaker.remote_function.custom_file_filter import CustomFileFilter
4242
from sagemaker.s3 import s3_path_join
4343

4444

@@ -66,7 +66,6 @@ def prepare_step_input_channel_for_spark_mode(
6666
user_workspace_s3uri = self._prepare_and_upload_workspace(
6767
dependencies_list_path,
6868
self.remote_decorator_config.include_local_workdir,
69-
self.remote_decorator_config.workdir_config,
7069
self.remote_decorator_config.pre_execution_commands,
7170
self.remote_decorator_config.pre_execution_script,
7271
s3_base_uri,
@@ -132,19 +131,17 @@ def _prepare_and_upload_workspace(
132131
self,
133132
local_dependencies_path: str,
134133
include_local_workdir: bool,
135-
workdir_config: WorkdirConfig,
136134
pre_execution_commands: List[str],
137135
pre_execution_script_local_path: str,
138136
s3_base_uri: str,
139137
s3_kms_key: str,
140138
sagemaker_session: Session,
141-
custom_file_filter: Optional[Callable[[str, List], List]] = None,
139+
custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
142140
) -> str:
143141
"""Upload the training step dependencies to S3 if present"""
144142
return _prepare_and_upload_workspace(
145143
local_dependencies_path=local_dependencies_path,
146144
include_local_workdir=include_local_workdir,
147-
workdir_config=workdir_config,
148145
pre_execution_commands=pre_execution_commands,
149146
pre_execution_script_local_path=pre_execution_script_local_path,
150147
s3_base_uri=s3_base_uri,

src/sagemaker/remote_function/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@
1515

1616
from sagemaker.remote_function.client import remote, RemoteExecutor # noqa: F401
1717
from sagemaker.remote_function.checkpoint_location import CheckpointLocation # noqa: F401
18-
from sagemaker.remote_function.workdir_config import WorkdirConfig # noqa: F401
18+
from sagemaker.remote_function.custom_file_filter import CustomFileFilter # noqa: F401
1919
from sagemaker.remote_function.spark_config import SparkConfig # noqa: F401

src/sagemaker/remote_function/client.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections import deque
1818
import time
1919
import threading
20-
from typing import Callable, Dict, List, Optional, Tuple, Any
20+
from typing import Callable, Dict, List, Optional, Tuple, Any, Union
2121
import functools
2222
import itertools
2323
import inspect
@@ -39,7 +39,7 @@
3939
from sagemaker.remote_function import logging_config
4040
from sagemaker.utils import name_from_base, base_from_name
4141
from sagemaker.remote_function.spark_config import SparkConfig
42-
from sagemaker.remote_function.workdir_config import WorkdirConfig
42+
from sagemaker.remote_function.custom_file_filter import CustomFileFilter
4343

4444
_API_CALL_LIMIT = {
4545
"SubmittingIntervalInSecs": 1,
@@ -66,7 +66,7 @@ def remote(
6666
environment_variables: Dict[str, str] = None,
6767
image_uri: str = None,
6868
include_local_workdir: bool = False,
69-
workdir_config: WorkdirConfig = None,
69+
custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
7070
instance_count: int = 1,
7171
instance_type: str = None,
7272
job_conda_env: str = None,
@@ -87,7 +87,6 @@ def remote(
8787
spark_config: SparkConfig = None,
8888
use_spot_instances=False,
8989
max_wait_time_in_seconds=None,
90-
custom_file_filter: Optional[Callable[[str, List], List]] = None,
9190
):
9291
"""Decorator for running the annotated function as a SageMaker training job.
9392
@@ -195,10 +194,12 @@ def remote(
195194
methods that are not available via PyPI or conda. Only python files are included.
196195
Default value is ``False``.
197196
198-
workdir_config (WorkdirConfig): A ``WorkdirConfig`` object that specifies the
199-
local directories and files to be included in the remote function.
200-
workdir_config takes precedence over include_local_workdir.
201-
Default value is ``None``.
197+
custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
198+
that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
199+
that specifies the local directories and files to be included in the remote function.
200+
If a callable is passed in, that function is passed to the ``ignore`` argument of
201+
``shutil.copytree``. Defaults to ``None``, which means only python
202+
files are accepted and uploaded to S3.
202203
203204
instance_count (int): The number of instances to use. Defaults to 1.
204205
NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
@@ -274,11 +275,6 @@ def remote(
274275
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
275276
After this amount of time Amazon SageMaker will stop waiting for managed spot training
276277
job to complete. Defaults to ``None``.
277-
278-
custom_file_filter (Callable[[str, List], List]): A function that filters job
279-
dependencies to be uploaded to S3. This function is passed to the ``ignore``
280-
argument of ``shutil.copytree``. Defaults to ``None``, which means only python
281-
files are accepted.
282278
"""
283279

284280
def _remote(func):
@@ -290,7 +286,7 @@ def _remote(func):
290286
environment_variables=environment_variables,
291287
image_uri=image_uri,
292288
include_local_workdir=include_local_workdir,
293-
workdir_config=workdir_config,
289+
custom_file_filter=custom_file_filter,
294290
instance_count=instance_count,
295291
instance_type=instance_type,
296292
job_conda_env=job_conda_env,
@@ -311,7 +307,6 @@ def _remote(func):
311307
spark_config=spark_config,
312308
use_spot_instances=use_spot_instances,
313309
max_wait_time_in_seconds=max_wait_time_in_seconds,
314-
custom_file_filter=custom_file_filter,
315310
)
316311

317312
@functools.wraps(func)
@@ -501,7 +496,7 @@ def __init__(
501496
environment_variables: Dict[str, str] = None,
502497
image_uri: str = None,
503498
include_local_workdir: bool = False,
504-
workdir_config: WorkdirConfig = None,
499+
custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
505500
instance_count: int = 1,
506501
instance_type: str = None,
507502
job_conda_env: str = None,
@@ -523,7 +518,6 @@ def __init__(
523518
spark_config: SparkConfig = None,
524519
use_spot_instances=False,
525520
max_wait_time_in_seconds=None,
526-
custom_file_filter: Optional[Callable[[str, List], List]] = None,
527521
):
528522
"""Constructor for RemoteExecutor
529523
@@ -628,10 +622,12 @@ def __init__(
628622
local directories. Set to ``True`` if the remote function code imports local modules
629623
and methods that are not available via PyPI or conda. Default value is ``False``.
630624
631-
workdir_config (WorkdirConfig): A ``WorkdirConfig`` object that specifies the
632-
local directories and files to be included in the remote function.
633-
workdir_config takes precedence over include_local_workdir.
634-
Default value is ``None``.
625+
custom_file_filter (Callable[[str, List], List], CustomFileFilter): Either a function
626+
that filters job dependencies to be uploaded to S3 or a ``CustomFileFilter`` object
627+
that specifies the local directories and files to be included in the remote function.
628+
If a callable is passed in, that function is passed to the ``ignore`` argument of
629+
``shutil.copytree``. Defaults to ``None``, which means only python
630+
files are accepted and uploaded to S3.
635631
636632
instance_count (int): The number of instances to use. Defaults to 1.
637633
NOTE: Remote function does not support instance_count > 1 for non Spark jobs.
@@ -715,11 +711,6 @@ def __init__(
715711
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
716712
After this amount of time Amazon SageMaker will stop waiting for managed spot training
717713
job to complete. Defaults to ``None``.
718-
719-
custom_file_filter (Callable[[str, List], List]): A function that filters job
720-
dependencies to be uploaded to S3. This function is passed to the ``ignore``
721-
argument of ``shutil.copytree``. Defaults to ``None``, which means only python
722-
files are accepted.
723714
"""
724715
self.max_parallel_jobs = max_parallel_jobs
725716

@@ -739,7 +730,7 @@ def __init__(
739730
environment_variables=environment_variables,
740731
image_uri=image_uri,
741732
include_local_workdir=include_local_workdir,
742-
workdir_config=workdir_config,
733+
custom_file_filter=custom_file_filter,
743734
instance_count=instance_count,
744735
instance_type=instance_type,
745736
job_conda_env=job_conda_env,
@@ -760,7 +751,6 @@ def __init__(
760751
spark_config=spark_config,
761752
use_spot_instances=use_spot_instances,
762753
max_wait_time_in_seconds=max_wait_time_in_seconds,
763-
custom_file_filter=custom_file_filter,
764754
)
765755

766756
self._state_condition = threading.Condition()

src/sagemaker/remote_function/workdir_config.py renamed to src/sagemaker/remote_function/custom_file_filter.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@
1616
import fnmatch
1717
import os
1818
import shutil
19-
from typing import List, Optional, Callable
19+
from typing import List, Optional, Callable, Union
2020

2121
from sagemaker.utils import resolve_value_from_config
22-
from sagemaker.config.config_schema import REMOTE_FUNCTION_PATH
22+
from sagemaker.config.config_schema import REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER
2323

2424

25-
class WorkdirConfig:
25+
class CustomFileFilter:
2626
"""Configuration that specifies how the local working directory should be packaged."""
2727

2828
def __init__(self, *, ignore_name_patterns: List[str] = None):
29-
"""Initialize a WorkdirConfig.
29+
"""Initialize a CustomFileFilter.
3030
3131
Args:
3232
ignore_name_patterns (List[str]): ignore files or directories with names
@@ -50,59 +50,51 @@ def workdir(self):
5050
return self._workdir
5151

5252

53-
def resolve_workdir_config_from_config_file(
54-
direct_input: WorkdirConfig = None, sagemaker_session=None
55-
) -> WorkdirConfig:
56-
"""Resolve the workdir configuration from the config file.
53+
def resolve_custom_file_filter_from_config_file(
54+
direct_input: Union[Callable[[str, List], List], CustomFileFilter] = None,
55+
sagemaker_session=None,
56+
) -> Union[Callable[[str, List], List], CustomFileFilter, None]:
57+
"""Resolve the CustomFileFilter configuration from the config file.
5758
5859
Args:
59-
direct_input (WorkdirConfig): direct input from the user.
60+
direct_input (Callable[[str, List], List], CustomFileFilter): direct input from the user.
6061
sagemaker_session (sagemaker.session.Session): sagemaker session.
6162
Returns:
62-
WorkdirConfig: configuration that specifies how the local
63+
CustomFileFilter: configuration that specifies how the local
6364
working directory should be packaged.
6465
"""
6566
if direct_input is not None:
6667
return direct_input
6768
ignore_name_patterns = resolve_value_from_config(
6869
direct_input=None,
69-
config_path=".".join([REMOTE_FUNCTION_PATH, "WorkdirConfig", "IgnoreNamePatterns"]),
70+
config_path=".".join([REMOTE_FUNCTION_PATH, CUSTOM_FILE_FILTER, "IgnoreNamePatterns"]),
7071
default_value=None,
7172
sagemaker_session=sagemaker_session,
7273
)
7374
if ignore_name_patterns is not None:
74-
return WorkdirConfig(ignore_name_patterns=ignore_name_patterns)
75+
return CustomFileFilter(ignore_name_patterns=ignore_name_patterns)
7576
return None
7677

7778

78-
def copy_workdir(workdir_config: WorkdirConfig, dst: str):
79+
def copy_workdir(
80+
dst: str,
81+
custom_file_filter: Optional[Union[Callable[[str, List], List], CustomFileFilter]] = None,
82+
):
7983
"""Copy the local working directory to the destination.
8084
8185
Args:
82-
workdir_config (WorkdirConfig): configuration that specifies how the local
83-
working directory should be packaged.
8486
dst (str): destination path.
87+
custom_file_filter (Union[Callable[[str, List], List], CustomFileFilter]): configuration that
88+
specifies how the local working directory should be packaged.
8589
"""
8690

8791
def _ignore_patterns(path: str, names: List): # pylint: disable=unused-argument
8892
ignored_names = set()
89-
if workdir_config.ignore_name_patterns is not None:
90-
for pattern in workdir_config.ignore_name_patterns:
93+
if custom_file_filter.ignore_name_patterns is not None:
94+
for pattern in custom_file_filter.ignore_name_patterns:
9195
ignored_names.update(fnmatch.filter(names, pattern))
9296
return ignored_names
9397

94-
shutil.copytree(
95-
workdir_config.workdir,
96-
dst,
97-
ignore=_ignore_patterns,
98-
)
99-
100-
101-
def copy_local_files(
102-
custom_file_filter: Optional[Callable[[str, List], List]], workdir: str, dst: str
103-
):
104-
"""Copy files from the local working directory to the destination."""
105-
10698
def _filter_non_python_files(path: str, names: List) -> List:
10799
"""Ignore function for filtering out non python files."""
108100
to_ignore = []
@@ -119,9 +111,18 @@ def _filter_non_python_files(path: str, names: List) -> List:
119111

120112
return to_ignore
121113

122-
ignore = custom_file_filter if custom_file_filter is not None else _filter_non_python_files
114+
_ignore = None
115+
_src = os.getcwd()
116+
if not custom_file_filter:
117+
_ignore = _filter_non_python_files
118+
elif callable(custom_file_filter):
119+
_ignore = custom_file_filter
120+
elif isinstance(custom_file_filter, CustomFileFilter):
121+
_ignore = _ignore_patterns
122+
_src = custom_file_filter.workdir
123+
123124
shutil.copytree(
124-
workdir,
125+
_src,
125126
dst,
126-
ignore=ignore,
127+
ignore=_ignore,
127128
)

0 commit comments

Comments
 (0)