@@ -193,6 +193,7 @@ def __init__(
193
193
spark_config : SparkConfig = None ,
194
194
use_spot_instances = False ,
195
195
max_wait_time_in_seconds = None ,
196
+ python_files_only : bool = True ,
196
197
):
197
198
"""Initialize a _JobSettings instance which configures the remote job.
198
199
@@ -363,6 +364,9 @@ def __init__(
363
364
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
364
365
After this amount of time Amazon SageMaker will stop waiting for managed spot
365
366
training job to complete. Defaults to ``None``.
367
+
368
+ python_files_only (bool): Specifies whether non-Python files are excluded from job
369
+ dependencies. Defaults to ``True``.
366
370
"""
367
371
self .sagemaker_session = sagemaker_session or Session ()
368
372
self .environment_variables = resolve_value_from_config (
@@ -450,6 +454,7 @@ def __init__(
450
454
self .keep_alive_period_in_seconds = keep_alive_period_in_seconds
451
455
self .spark_config = spark_config
452
456
self .use_spot_instances = use_spot_instances
457
+ self .python_files_only = python_files_only
453
458
self .max_wait_time_in_seconds = max_wait_time_in_seconds
454
459
self .job_conda_env = resolve_value_from_config (
455
460
direct_input = job_conda_env ,
@@ -649,6 +654,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
649
654
s3_base_uri = s3_base_uri ,
650
655
s3_kms_key = job_settings .s3_kms_key ,
651
656
sagemaker_session = job_settings .sagemaker_session ,
657
+ python_files_only = job_settings .python_files_only ,
652
658
)
653
659
654
660
stored_function = StoredFunction (
@@ -890,6 +896,7 @@ def _prepare_and_upload_dependencies(
890
896
s3_base_uri : str ,
891
897
s3_kms_key : str ,
892
898
sagemaker_session : Session ,
899
+ python_files_only : bool ,
893
900
) -> str :
894
901
"""Upload the job dependencies to S3 if present"""
895
902
@@ -906,12 +913,12 @@ def _prepare_and_upload_dependencies(
906
913
os .mkdir (tmp_workspace_dir )
907
914
# TODO Remove the following hack to avoid dir_exists error in the copy_tree call below.
908
915
tmp_workspace = os .path .join (tmp_workspace_dir , JOB_REMOTE_FUNCTION_WORKSPACE )
909
-
916
+ ignore = _filter_non_python_files if python_files_only else None
910
917
if include_local_workdir :
911
918
shutil .copytree (
912
919
os .getcwd (),
913
920
tmp_workspace ,
914
- ignore = _filter_non_python_files ,
921
+ ignore = ignore ,
915
922
)
916
923
logger .info ("Copied user workspace python scripts to '%s'" , tmp_workspace )
917
924
0 commit comments