Skip to content

Commit a9ac311

Browse files
authored
feature: allow non-python files in job dependencies (#4138)
* feature: add custom_file_filter argument to job setting
* change: Add more tests
1 parent 412e8ba commit a9ac311

File tree

7 files changed

+147
-12
lines changed

7 files changed

+147
-12
lines changed

src/sagemaker/feature_store/feature_processor/_config_uploader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Contains classes for preparing and uploading configs for a scheduled feature processor."""
1414
from __future__ import absolute_import
15-
from typing import Callable, Dict, Tuple, List
15+
from typing import Callable, Dict, Optional, Tuple, List
1616

1717
import attr
1818

@@ -70,6 +70,7 @@ def prepare_step_input_channel_for_spark_mode(
7070
s3_base_uri,
7171
self.remote_decorator_config.s3_kms_key,
7272
sagemaker_session,
73+
self.remote_decorator_config.custom_file_filter,
7374
)
7475

7576
(
@@ -134,6 +135,7 @@ def _prepare_and_upload_dependencies(
134135
s3_base_uri: str,
135136
s3_kms_key: str,
136137
sagemaker_session: Session,
138+
custom_file_filter: Optional[Callable[[str, List], List]] = None,
137139
) -> str:
138140
"""Upload the training step dependencies to S3 if present"""
139141
return _prepare_and_upload_dependencies(
@@ -144,6 +146,7 @@ def _prepare_and_upload_dependencies(
144146
s3_base_uri=s3_base_uri,
145147
s3_kms_key=s3_kms_key,
146148
sagemaker_session=sagemaker_session,
149+
custom_file_filter=custom_file_filter,
147150
)
148151

149152
def _prepare_and_upload_runtime_scripts(

src/sagemaker/remote_function/client.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from collections import deque
1818
import time
1919
import threading
20-
from typing import Dict, List, Tuple, Any
20+
from typing import Callable, Dict, List, Optional, Tuple, Any
2121
import functools
2222
import itertools
2323
import inspect
@@ -85,6 +85,7 @@ def remote(
8585
spark_config: SparkConfig = None,
8686
use_spot_instances=False,
8787
max_wait_time_in_seconds=None,
88+
custom_file_filter: Optional[Callable[[str, List], List]] = None,
8889
):
8990
"""Decorator for running the annotated function as a SageMaker training job.
9091
@@ -265,6 +266,11 @@ def remote(
265266
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
266267
After this amount of time Amazon SageMaker will stop waiting for managed spot training
267268
job to complete. Defaults to ``None``.
269+
270+
custom_file_filter (Callable[[str, List], List]): A function that filters job
271+
dependencies to be uploaded to S3. This function is passed to the ``ignore``
272+
argument of ``shutil.copytree``. Defaults to ``None``, which means only python
273+
files are accepted.
268274
"""
269275

270276
def _remote(func):
@@ -296,6 +302,7 @@ def _remote(func):
296302
spark_config=spark_config,
297303
use_spot_instances=use_spot_instances,
298304
max_wait_time_in_seconds=max_wait_time_in_seconds,
305+
custom_file_filter=custom_file_filter,
299306
)
300307

301308
@functools.wraps(func)
@@ -506,6 +513,7 @@ def __init__(
506513
spark_config: SparkConfig = None,
507514
use_spot_instances=False,
508515
max_wait_time_in_seconds=None,
516+
custom_file_filter: Optional[Callable[[str, List], List]] = None,
509517
):
510518
"""Constructor for RemoteExecutor
511519
@@ -692,6 +700,11 @@ def __init__(
692700
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
693701
After this amount of time Amazon SageMaker will stop waiting for managed spot training
694702
job to complete. Defaults to ``None``.
703+
704+
custom_file_filter (Callable[[str, List], List]): A function that filters job
705+
dependencies to be uploaded to S3. This function is passed to the ``ignore``
706+
argument of ``shutil.copytree``. Defaults to ``None``, which means only python
707+
files are accepted.
695708
"""
696709
self.max_parallel_jobs = max_parallel_jobs
697710

@@ -731,6 +744,7 @@ def __init__(
731744
spark_config=spark_config,
732745
use_spot_instances=use_spot_instances,
733746
max_wait_time_in_seconds=max_wait_time_in_seconds,
747+
custom_file_filter=custom_file_filter,
734748
)
735749

736750
self._state_condition = threading.Condition()

src/sagemaker/remote_function/job.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import sys
2121
import json
2222
import secrets
23-
from typing import Dict, List, Tuple
23+
from typing import Callable, Dict, List, Optional, Tuple
2424
from urllib.parse import urlparse
2525
from io import BytesIO
2626

@@ -193,6 +193,7 @@ def __init__(
193193
spark_config: SparkConfig = None,
194194
use_spot_instances=False,
195195
max_wait_time_in_seconds=None,
196+
custom_file_filter: Optional[Callable[[str, List], List]] = None,
196197
):
197198
"""Initialize a _JobSettings instance which configures the remote job.
198199
@@ -363,6 +364,11 @@ def __init__(
363364
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
364365
After this amount of time Amazon SageMaker will stop waiting for managed spot
365366
training job to complete. Defaults to ``None``.
367+
368+
custom_file_filter (Callable[[str, List], List]): A function that filters job
369+
dependencies to be uploaded to S3. This function is passed to the ``ignore``
370+
argument of ``shutil.copytree``. Defaults to ``None``, which means only python
371+
files are accepted.
366372
"""
367373
self.sagemaker_session = sagemaker_session or Session()
368374
self.environment_variables = resolve_value_from_config(
@@ -450,6 +456,7 @@ def __init__(
450456
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
451457
self.spark_config = spark_config
452458
self.use_spot_instances = use_spot_instances
459+
self.custom_file_filter = custom_file_filter
453460
self.max_wait_time_in_seconds = max_wait_time_in_seconds
454461
self.job_conda_env = resolve_value_from_config(
455462
direct_input=job_conda_env,
@@ -649,6 +656,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
649656
s3_base_uri=s3_base_uri,
650657
s3_kms_key=job_settings.s3_kms_key,
651658
sagemaker_session=job_settings.sagemaker_session,
659+
custom_file_filter=job_settings.custom_file_filter,
652660
)
653661

654662
stored_function = StoredFunction(
@@ -890,6 +898,7 @@ def _prepare_and_upload_dependencies(
890898
s3_base_uri: str,
891899
s3_kms_key: str,
892900
sagemaker_session: Session,
901+
custom_file_filter: Optional[Callable[[str, List], List]] = None,
893902
) -> str:
894903
"""Upload the job dependencies to S3 if present"""
895904

@@ -906,12 +915,12 @@ def _prepare_and_upload_dependencies(
906915
os.mkdir(tmp_workspace_dir)
907916
# TODO Remove the following hack to avoid dir_exists error in the copy_tree call below.
908917
tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE)
909-
918+
ignore = custom_file_filter if custom_file_filter is not None else _filter_non_python_files
910919
if include_local_workdir:
911920
shutil.copytree(
912921
os.getcwd(),
913922
tmp_workspace,
914-
ignore=_filter_non_python_files,
923+
ignore=ignore,
915924
)
916925
logger.info("Copied user workspace python scripts to '%s'", tmp_workspace)
917926

tests/unit/sagemaker/feature_store/feature_processor/test_config_uploader.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def runtime_env_manager():
5050
return mocked_runtime_env_manager
5151

5252

53+
def custom_file_filter():
54+
pass
55+
56+
5357
@pytest.fixture
5458
def remote_decorator_config(sagemaker_session):
5559
return Mock(
@@ -64,6 +68,7 @@ def remote_decorator_config(sagemaker_session):
6468
pre_execution_script="some_path",
6569
python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
6670
environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
71+
custom_file_filter=None,
6772
)
6873

6974

@@ -72,6 +77,24 @@ def config_uploader(remote_decorator_config, runtime_env_manager):
7277
return ConfigUploader(remote_decorator_config, runtime_env_manager)
7378

7479

80+
@pytest.fixture
81+
def remote_decorator_config_with_filter(sagemaker_session):
82+
return Mock(
83+
_JobSettings,
84+
sagemaker_session=sagemaker_session,
85+
s3_root_uri="some_s3_uri",
86+
s3_kms_key="some_kms",
87+
spark_config=SparkConfig(),
88+
dependencies=None,
89+
include_local_workdir=True,
90+
pre_execution_commands="some_commands",
91+
pre_execution_script="some_path",
92+
python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
93+
environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
94+
custom_file_filter=custom_file_filter,
95+
)
96+
97+
7598
@patch("sagemaker.feature_store.feature_processor._config_uploader.StoredFunction")
7699
def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrapped_func):
77100
mock_stored_function.save(wrapped_func).return_value = None
@@ -108,6 +131,42 @@ def test_prepare_and_upload_dependencies(mock_upload, config_uploader):
108131
s3_base_uri=remote_decorator_config.s3_root_uri,
109132
s3_kms_key=remote_decorator_config.s3_kms_key,
110133
sagemaker_session=sagemaker_session,
134+
custom_file_filter=None,
135+
)
136+
137+
138+
@patch(
139+
"sagemaker.feature_store.feature_processor._config_uploader._prepare_and_upload_dependencies",
140+
return_value="some_s3_uri",
141+
)
142+
def test_prepare_and_upload_dependencies_with_filter(
143+
mock_job_upload, remote_decorator_config_with_filter, runtime_env_manager
144+
):
145+
config_uploader_with_filter = ConfigUploader(
146+
remote_decorator_config=remote_decorator_config_with_filter,
147+
runtime_env_manager=runtime_env_manager,
148+
)
149+
remote_decorator_config = config_uploader_with_filter.remote_decorator_config
150+
config_uploader_with_filter._prepare_and_upload_dependencies(
151+
local_dependencies_path="some/path/to/dependency",
152+
include_local_workdir=True,
153+
pre_execution_commands=remote_decorator_config.pre_execution_commands,
154+
pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
155+
s3_base_uri=remote_decorator_config.s3_root_uri,
156+
s3_kms_key=remote_decorator_config.s3_kms_key,
157+
sagemaker_session=sagemaker_session,
158+
custom_file_filter=remote_decorator_config_with_filter.custom_file_filter,
159+
)
160+
161+
mock_job_upload.assert_called_once_with(
162+
local_dependencies_path="some/path/to/dependency",
163+
include_local_workdir=True,
164+
pre_execution_commands=remote_decorator_config.pre_execution_commands,
165+
pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
166+
s3_base_uri=remote_decorator_config.s3_root_uri,
167+
s3_kms_key=remote_decorator_config.s3_kms_key,
168+
sagemaker_session=sagemaker_session,
169+
custom_file_filter=custom_file_filter,
111170
)
112171

113172

@@ -201,6 +260,7 @@ def test_prepare_step_input_channel(
201260
s3_base_uri=remote_decorator_config.s3_root_uri,
202261
s3_kms_key="some_kms",
203262
sagemaker_session=sagemaker_session,
263+
custom_file_filter=None,
204264
)
205265

206266
mock_spark_dependency_upload.assert_called_once_with(

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,7 @@ def test_to_pipeline(
263263
)
264264

265265
mock_dependency_upload.assert_called_once_with(
266-
local_dependencies_path,
267-
True,
268-
None,
269-
None,
270-
f"{S3_URI}/pipeline_name",
271-
None,
272-
session,
266+
local_dependencies_path, True, None, None, f"{S3_URI}/pipeline_name", None, session, None
273267
)
274268

275269
mock_spark_dependency_upload.assert_called_once_with(
@@ -875,6 +869,7 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
875869
"tags",
876870
"use_spot_instances",
877871
"max_wait_time_in_seconds",
872+
"custom_file_filter",
878873
}
879874

880875
job_settings = _JobSettings(

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,32 @@ def square(x):
175175
assert mock_job_settings.call_args.kwargs["image_uri"] == IMAGE
176176

177177

178+
@patch(
179+
"sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
180+
return_value=EXPECTED_JOB_RESULT,
181+
)
182+
@patch("sagemaker.remote_function.client._JobSettings")
183+
@patch("sagemaker.remote_function.client._Job.start")
184+
def test_decorator_with_custom_file_filter(
185+
mock_start, mock_job_settings, mock_deserialize_obj_from_s3
186+
):
187+
mock_job = Mock(job_name=TRAINING_JOB_NAME)
188+
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
189+
190+
mock_start.return_value = mock_job
191+
192+
def custom_file_filter():
193+
pass
194+
195+
@remote(image_uri=IMAGE, s3_root_uri=S3_URI, custom_file_filter=custom_file_filter)
196+
def square(x):
197+
return x * x
198+
199+
result = square(5)
200+
assert result == EXPECTED_JOB_RESULT
201+
assert mock_job_settings.call_args.kwargs["custom_file_filter"] == custom_file_filter
202+
203+
178204
@patch(
179205
"sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
180206
return_value=ZeroDivisionError("division by zero"),

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ def test_start(
358358
s3_base_uri=f"{S3_URI}/{job.job_name}",
359359
s3_kms_key=None,
360360
sagemaker_session=session(),
361+
custom_file_filter=None,
361362
)
362363

363364
session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -480,6 +481,7 @@ def test_start_with_complete_job_settings(
480481
s3_base_uri=f"{S3_URI}/{job.job_name}",
481482
s3_kms_key=job_settings.s3_kms_key,
482483
sagemaker_session=session(),
484+
custom_file_filter=None,
483485
)
484486

485487
session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -778,6 +780,32 @@ def test_prepare_and_upload_dependencies(session, mock_copytree, mock_copy, mock
778780
)
779781

780782

783+
@patch("sagemaker.s3.S3Uploader.upload", return_value="some_uri")
784+
@patch("shutil.copy2")
785+
@patch("shutil.copytree")
786+
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
787+
def test_prepare_and_upload_dependencies_with_custom_filter(
788+
session, mock_copytree, mock_copy, mock_s3_upload
789+
):
790+
def custom_file_filter():
791+
pass
792+
793+
s3_path = _prepare_and_upload_dependencies(
794+
local_dependencies_path="some/path/to/dependency",
795+
include_local_workdir=True,
796+
pre_execution_commands=["cmd_1", "cmd_2"],
797+
pre_execution_script_local_path=None,
798+
s3_base_uri=S3_URI,
799+
s3_kms_key=KMS_KEY_ARN,
800+
sagemaker_session=session,
801+
custom_file_filter=custom_file_filter,
802+
)
803+
804+
assert s3_path == mock_s3_upload.return_value
805+
806+
mock_copytree.assert_called_with(os.getcwd(), ANY, ignore=custom_file_filter)
807+
808+
781809
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
782810
def test_prepare_and_upload_spark_dependent_file_without_spark_config(session):
783811
assert _prepare_and_upload_spark_dependent_files(

0 commit comments

Comments (0)