
Commit b26397e

Add python_files_only to job setting

1 parent: e100e0a

File tree: 2 files changed (+17, -2)

  src/sagemaker/remote_function/client.py
  src/sagemaker/remote_function/job.py

src/sagemaker/remote_function/client.py (8 additions, 0 deletions)

@@ -85,6 +85,7 @@ def remote(
     spark_config: SparkConfig = None,
     use_spot_instances=False,
     max_wait_time_in_seconds=None,
+    python_files_only=True,
 ):
     """Decorator for running the annotated function as a SageMaker training job.

@@ -265,6 +266,8 @@ def remote(
         max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
             After this amount of time Amazon SageMaker will stop waiting for managed spot training
             job to complete. Defaults to ``None``.
+
+        python_files_only (bool): Specifies whether non-python files are allowed in job dependencies. Defaults to ``True``.
     """

     def _remote(func):

@@ -296,6 +299,7 @@ def _remote(func):
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
+            python_files_only=python_files_only,
         )

         @functools.wraps(func)

@@ -506,6 +510,7 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        python_files_only=True,
     ):
         """Constructor for RemoteExecutor

@@ -692,6 +697,8 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot training
                 job to complete. Defaults to ``None``.
+
+            python_files_only (bool): Specifies whether non-python files are allowed in job dependencies. Defaults to ``True``.
         """
         self.max_parallel_jobs = max_parallel_jobs

@@ -731,6 +738,7 @@ def __init__(
             spark_config=spark_config,
             use_spot_instances=use_spot_instances,
             max_wait_time_in_seconds=max_wait_time_in_seconds,
+            python_files_only=python_files_only,
         )

         self._state_condition = threading.Condition()
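
For context, a hedged usage sketch of how the new keyword surfaces through the decorator. The import path matches the SDK, but the instance type, the other settings, and the config file below are illustrative and not part of this commit:

# Illustrative usage sketch (not part of the commit): the new python_files_only
# keyword is forwarded from the decorator down to _JobSettings, so setting it to
# False lets non-python files in the local workdir travel with the job.
from sagemaker.remote_function import remote

@remote(
    instance_type="ml.m5.xlarge",   # hypothetical instance type
    include_local_workdir=True,     # ship the current working directory with the job
    python_files_only=False,        # new flag from this commit: keep non-python files too
)
def train(config_path="config.json"):
    # With python_files_only=False, a non-python file copied from the workdir
    # (e.g. this hypothetical config.json) is available inside the job.
    import json
    with open(config_path) as f:
        return json.load(f)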

src/sagemaker/remote_function/job.py (9 additions, 2 deletions)

@@ -193,6 +193,7 @@ def __init__(
         spark_config: SparkConfig = None,
         use_spot_instances=False,
         max_wait_time_in_seconds=None,
+        python_files_only: bool = True,
     ):
         """Initialize a _JobSettings instance which configures the remote job.

@@ -363,6 +364,9 @@ def __init__(
             max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
                 After this amount of time Amazon SageMaker will stop waiting for managed spot
                 training job to complete. Defaults to ``None``.
+
+            python_files_only (bool): Specifies whether non-python files are allowed in job
+                dependencies. Defaults to ``True``.
         """
         self.sagemaker_session = sagemaker_session or Session()
         self.environment_variables = resolve_value_from_config(

@@ -450,6 +454,7 @@ def __init__(
         self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
         self.spark_config = spark_config
         self.use_spot_instances = use_spot_instances
+        self.python_files_only = python_files_only
         self.max_wait_time_in_seconds = max_wait_time_in_seconds
         self.job_conda_env = resolve_value_from_config(
             direct_input=job_conda_env,

@@ -649,6 +654,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None):
             s3_base_uri=s3_base_uri,
             s3_kms_key=job_settings.s3_kms_key,
             sagemaker_session=job_settings.sagemaker_session,
+            python_files_only=job_settings.python_files_only,
         )

         stored_function = StoredFunction(

@@ -890,6 +896,7 @@ def _prepare_and_upload_dependencies(
     s3_base_uri: str,
     s3_kms_key: str,
     sagemaker_session: Session,
+    python_files_only: bool,
 ) -> str:
     """Upload the job dependencies to S3 if present"""

@@ -906,12 +913,12 @@ def _prepare_and_upload_dependencies(
         os.mkdir(tmp_workspace_dir)
         # TODO Remove the following hack to avoid dir_exists error in the copy_tree call below.
         tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE)
-
+        ignore = _filter_non_python_files if python_files_only else None
         if include_local_workdir:
             shutil.copytree(
                 os.getcwd(),
                 tmp_workspace,
-                ignore=_filter_non_python_files,
+                ignore=ignore,
             )
             logger.info("Copied user workspace python scripts to '%s'", tmp_workspace)
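The behavioral change above hinges on the ignore hook of shutil.copytree: when python_files_only is False the hook is dropped (ignore=None) and every file in the workdir is copied. A self-contained sketch of that mechanism follows; the filter function is an illustrative stand-in, not the SDK's _filter_non_python_files:

import os
import shutil

def _keep_python_only(directory, names):
    # Illustrative ignore hook with the signature shutil.copytree expects:
    # (directory, names) -> collection of entry names to skip. Directories are
    # kept so the tree walk can still recurse into them.
    return {
        name
        for name in names
        if not name.endswith(".py")
        and not os.path.isdir(os.path.join(directory, name))
    }

def copy_workspace(src, dst, python_files_only=True):
    # Mirrors the pattern introduced in the diff: pick the ignore hook up front
    # and pass it (or None, meaning "copy everything") straight to copytree.
    ignore = _keep_python_only if python_files_only else None
    shutil.copytree(src, dst, ignore=ignore)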