Commit b32fe65

feature: add custom_file_filter argument to job setting
1 parent 41feb4c commit b32fe65

6 files changed: +64 -12 lines

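In short, the new ``custom_file_filter`` argument lets callers control which files from the local working directory are packaged and uploaded to S3 as job dependencies. A minimal usage sketch (the filter follows the ``ignore`` callback contract of ``shutil.copytree``: it receives a directory and its entries and returns the names to skip; the filter body and instance type here are illustrative):

    from sagemaker.remote_function import remote

    def skip_notebooks_and_tests(directory, contents):
        # Return the entries that should NOT be uploaded with the job.
        return [name for name in contents if name.endswith(".ipynb") or name == "tests"]

    @remote(instance_type="ml.m5.xlarge", custom_file_filter=skip_notebooks_and_tests)
    def train(x):
        return x * 2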

src/sagemaker/feature_store/feature_processor/_config_uploader.py

Lines changed: 4 additions & 1 deletion
@@ -12,7 +12,7 @@
 # language governing permissions and limitations under the License.
 """Contains classes for preparing and uploading configs for a scheduled feature processor."""
 from __future__ import absolute_import
-from typing import Callable, Dict, Tuple, List
+from typing import Callable, Dict, Optional, Tuple, List

 import attr

@@ -70,6 +70,7 @@ def prepare_step_input_channel_for_spark_mode(
             s3_base_uri,
             self.remote_decorator_config.s3_kms_key,
             sagemaker_session,
+            self.remote_decorator_config.custom_file_filter,
         )

         (
@@ -134,6 +135,7 @@ def _prepare_and_upload_dependencies(
         s3_base_uri: str,
         s3_kms_key: str,
         sagemaker_session: Session,
+        custom_file_filter: Optional[Callable[[str, List], List]] = None,
     ) -> str:
         """Upload the training step dependencies to S3 if present"""
         return _prepare_and_upload_dependencies(
@@ -144,6 +146,7 @@ def _prepare_and_upload_dependencies(
             s3_base_uri=s3_base_uri,
             s3_kms_key=s3_kms_key,
             sagemaker_session=sagemaker_session,
+            custom_file_filter=custom_file_filter,
         )

     def _prepare_and_upload_runtime_scripts(
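For scheduled feature processors, the filter travels inside the remote decorator config and is forwarded unchanged to the shared upload helper above. As an illustration only (the extension list is an assumption, not part of this commit), a filter that keeps Python sources plus YAML configs could look like:

    import os

    def keep_python_and_yaml(directory, contents):
        # Names returned here are skipped by shutil.copytree; directories are
        # kept so that the copy can recurse into them.
        keep = (".py", ".yaml", ".yml")
        return [
            name
            for name in contents
            if not name.endswith(keep) and not os.path.isdir(os.path.join(directory, name))
        ]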

src/sagemaker/remote_function/client.py

Lines changed: 15 additions & 1 deletion
@@ -17,7 +17,7 @@
 from collections import deque
 import time
 import threading
-from typing import Dict, List, Tuple, Any
+from typing import Callable, Dict, List, Optional, Tuple, Any
 import functools
 import itertools
 import inspect
@@ -85,6 +85,7 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
+    custom_file_filter: Optional[Callable[[str, List], List]] = None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -265,6 +266,11 @@ def remote(
         max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
             After this amount of time Amazon SageMaker will stop waiting for managed spot training
             job to complete. Defaults to ``None``.
+
+        custom_file_filter (Callable[[str, List], List]): A function that filters job
+            dependencies to be uploaded to S3. This function is passed to the ``ignore``
+            argument of ``shutil.copytree``. Defaults to ``None``, which means only python
+            files are accepted.
     """

     def _remote(func):
@@ -296,6 +302,7 @@ def _remote(func):
            spark_config=spark_config,
            use_spot_instances=use_spot_instances,
            max_wait_time_in_seconds=max_wait_time_in_seconds,
+           custom_file_filter=custom_file_filter,
        )

        @functools.wraps(func)
@@ -506,6 +513,7 @@ def __init__(
        spark_config: SparkConfig = None,
        use_spot_instances=False,
        max_wait_time_in_seconds=None,
+       custom_file_filter: Optional[Callable[[str, List], List]] = None,
    ):
        """Constructor for RemoteExecutor

@@ -692,6 +700,11 @@ def __init__(
            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                After this amount of time Amazon SageMaker will stop waiting for managed spot training
                job to complete. Defaults to ``None``.
+
+           custom_file_filter (Callable[[str, List], List]): A function that filters job
+               dependencies to be uploaded to S3. This function is passed to the ``ignore``
+               argument of ``shutil.copytree``. Defaults to ``None``, which means only python
+               files are accepted.
        """
        self.max_parallel_jobs = max_parallel_jobs

@@ -731,6 +744,7 @@ def __init__(
            spark_config=spark_config,
            use_spot_instances=use_spot_instances,
            max_wait_time_in_seconds=max_wait_time_in_seconds,
+           custom_file_filter=custom_file_filter,
        )

        self._state_condition = threading.Condition()
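Since the callable is handed straight to ``shutil.copytree``, the standard library's ``shutil.ignore_patterns`` factory already produces a compatible filter (it returns a set of names to ignore, which ``copytree`` accepts), so a caller does not have to write one by hand:

    import shutil
    from sagemaker.remote_function import remote

    # Skip bytecode caches, notebook checkpoints, and local data dumps.
    data_filter = shutil.ignore_patterns("__pycache__", ".ipynb_checkpoints", "*.csv")

    @remote(custom_file_filter=data_filter)
    def score(row):
        return row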

src/sagemaker/remote_function/job.py

Lines changed: 12 additions & 3 deletions
@@ -20,7 +20,7 @@
 import sys
 import json
 import secrets
-from typing import Dict, List, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
 from urllib.parse import urlparse
 from io import BytesIO

@@ -193,6 +193,7 @@ def __init__(
        spark_config: SparkConfig = None,
        use_spot_instances=False,
        max_wait_time_in_seconds=None,
+       custom_file_filter: Optional[Callable[[str, List], List]] = None,
    ):
        """Initialize a _JobSettings instance which configures the remote job.

@@ -363,6 +364,11 @@ def __init__(
            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                After this amount of time Amazon SageMaker will stop waiting for managed spot
                training job to complete. Defaults to ``None``.
+
+           custom_file_filter (Callable[[str, List], List]): A function that filters job
+               dependencies to be uploaded to S3. This function is passed to the ``ignore``
+               argument of ``shutil.copytree``. Defaults to ``None``, which means only python
+               files are accepted.
        """
        self.sagemaker_session = sagemaker_session or Session()
        self.environment_variables = resolve_value_from_config(
@@ -450,6 +456,7 @@ def __init__(
        self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
        self.spark_config = spark_config
        self.use_spot_instances = use_spot_instances
+       self.custom_file_filter = custom_file_filter
        self.max_wait_time_in_seconds = max_wait_time_in_seconds
        self.job_conda_env = resolve_value_from_config(
            direct_input=job_conda_env,
@@ -649,6 +656,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
            s3_base_uri=s3_base_uri,
            s3_kms_key=job_settings.s3_kms_key,
            sagemaker_session=job_settings.sagemaker_session,
+           custom_file_filter=job_settings.custom_file_filter,
        )

        stored_function = StoredFunction(
@@ -890,6 +898,7 @@ def _prepare_and_upload_dependencies(
    s3_base_uri: str,
    s3_kms_key: str,
    sagemaker_session: Session,
+   custom_file_filter: Optional[Callable[[str, List], List]] = None,
 ) -> str:
    """Upload the job dependencies to S3 if present"""

@@ -906,12 +915,12 @@ def _prepare_and_upload_dependencies(
        os.mkdir(tmp_workspace_dir)
        # TODO Remove the following hack to avoid dir_exists error in the copy_tree call below.
        tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE)
-
+       ignore = custom_file_filter if custom_file_filter is not None else _filter_non_python_files
        if include_local_workdir:
            shutil.copytree(
                os.getcwd(),
                tmp_workspace,
-               ignore=_filter_non_python_files,
+               ignore=ignore,
            )
            logger.info("Copied user workspace python scripts to '%s'", tmp_workspace)
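The ``ignore`` fallback above preserves the old behavior: when no filter is supplied, only Python files are copied. ``_filter_non_python_files`` is internal to the module; a rough stand-in that approximates its effect (an assumption for illustration, not the actual implementation):

    import os

    def filter_non_python_files(directory, contents):
        # Ignore any regular file that is not a .py file; keep directories
        # so the copy can recurse into them.
        return [
            name
            for name in contents
            if not name.endswith(".py") and not os.path.isdir(os.path.join(directory, name))
        ]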

tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,7 @@ def remote_decorator_config(sagemaker_session):
        pre_execution_script="some_path",
        python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
        environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
+       custom_file_filter=None,
    )


@@ -108,6 +109,7 @@ def test_prepare_and_upload_dependencies(mock_upload, config_uploader):
        s3_base_uri=remote_decorator_config.s3_root_uri,
        s3_kms_key=remote_decorator_config.s3_kms_key,
        sagemaker_session=sagemaker_session,
+       custom_file_filter=None,
    )


@@ -201,6 +203,7 @@ def test_prepare_step_input_channel(
        s3_base_uri=remote_decorator_config.s3_root_uri,
        s3_kms_key="some_kms",
        sagemaker_session=sagemaker_session,
+       custom_file_filter=None,
    )

    mock_spark_dependency_upload.assert_called_once_with(

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 2 additions & 7 deletions
@@ -263,13 +263,7 @@ def test_to_pipeline(
    )

    mock_dependency_upload.assert_called_once_with(
-       local_dependencies_path,
-       True,
-       None,
-       None,
-       f"{S3_URI}/pipeline_name",
-       None,
-       session,
+       local_dependencies_path, True, None, None, f"{S3_URI}/pipeline_name", None, session, None
    )

    mock_spark_dependency_upload.assert_called_once_with(
@@ -875,6 +869,7 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
        "tags",
        "use_spot_instances",
        "max_wait_time_in_seconds",
+       "custom_file_filter",
    }

    job_settings = _JobSettings(

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 28 additions & 0 deletions
@@ -358,6 +358,7 @@ def test_start(
        s3_base_uri=f"{S3_URI}/{job.job_name}",
        s3_kms_key=None,
        sagemaker_session=session(),
+       custom_file_filter=None,
    )

    session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -480,6 +481,7 @@ def test_start_with_complete_job_settings(
        s3_base_uri=f"{S3_URI}/{job.job_name}",
        s3_kms_key=job_settings.s3_kms_key,
        sagemaker_session=session(),
+       custom_file_filter=None,
    )

    session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -778,6 +780,32 @@ def test_prepare_and_upload_dependencies(session, mock_copytree, mock_copy, mock
    )


+@patch("sagemaker.s3.S3Uploader.upload", return_value="some_uri")
+@patch("shutil.copy2")
+@patch("shutil.copytree")
+@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
+def test_prepare_and_upload_dependencies_with_custom_filter(
+    session, mock_copytree, mock_copy, mock_s3_upload
+):
+    def custom_file_filter():
+        pass
+
+    s3_path = _prepare_and_upload_dependencies(
+        local_dependencies_path="some/path/to/dependency",
+        include_local_workdir=True,
+        pre_execution_commands=["cmd_1", "cmd_2"],
+        pre_execution_script_local_path=None,
+        s3_base_uri=S3_URI,
+        s3_kms_key=KMS_KEY_ARN,
+        sagemaker_session=session,
+        custom_file_filter=custom_file_filter,
+    )
+
+    assert s3_path == mock_s3_upload.return_value
+
+    mock_copytree.assert_called_with(os.getcwd(), ANY, ignore=custom_file_filter)
+
+
 @patch("sagemaker.remote_function.job.Session", return_value=mock_session())
 def test_prepare_and_upload_spark_dependent_file_without_spark_config(session):
    assert _prepare_and_upload_spark_dependent_files(
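The mocked test above only verifies the plumbing, i.e. that ``shutil.copytree`` receives the supplied filter. The filter's actual effect can also be smoke-tested without mocks in a scratch directory, for example:

    import os
    import shutil
    import tempfile

    src = tempfile.mkdtemp()
    dst = os.path.join(tempfile.mkdtemp(), "workspace")  # must not exist yet
    open(os.path.join(src, "train.py"), "w").close()
    open(os.path.join(src, "notes.ipynb"), "w").close()

    shutil.copytree(src, dst, ignore=shutil.ignore_patterns("*.ipynb"))
    assert os.listdir(dst) == ["train.py"]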
