feature: allow non-python files in job dependencies #4138

Merged (2 commits, Oct 9, 2023)
@@ -12,7 +12,7 @@
# language governing permissions and limitations under the License.
"""Contains classes for preparing and uploading configs for a scheduled feature processor."""
from __future__ import absolute_import
-from typing import Callable, Dict, Tuple, List
+from typing import Callable, Dict, Optional, Tuple, List
Collaborator:

Instead of using ``Optional[...]``, it would be better to use ``... | None``. That requires having the following at the top of the file in order to support Python < 3.10:

from __future__ import annotations
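For illustration, a minimal sketch of the suggested style (the helper name is hypothetical, not from this PR); with the future import, annotations are never evaluated at runtime, so the PEP 604 syntax parses on Python 3.7-3.9 as well:

# Sketch only: lazy annotations let `X | None` appear in signatures
# even on interpreters older than 3.10.
from __future__ import annotations

from typing import Callable, List


def _upload_workspace(custom_file_filter: Callable[[str, List], List] | None = None) -> None:
    """Hypothetical signature using ``... | None`` instead of ``Optional[...]``."""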

Collaborator (Author):

``Optional`` is used a lot in the code base; I'll address it in another PR.


import attr

@@ -70,6 +70,7 @@ def prepare_step_input_channel_for_spark_mode(
s3_base_uri,
self.remote_decorator_config.s3_kms_key,
sagemaker_session,
self.remote_decorator_config.custom_file_filter,
)

(
@@ -134,6 +135,7 @@ def _prepare_and_upload_dependencies(
s3_base_uri: str,
s3_kms_key: str,
sagemaker_session: Session,
custom_file_filter: Optional[Callable[[str, List], List]] = None,
) -> str:
"""Upload the training step dependencies to S3 if present"""
return _prepare_and_upload_dependencies(
@@ -144,6 +146,7 @@ def _prepare_and_upload_dependencies(
s3_base_uri=s3_base_uri,
s3_kms_key=s3_kms_key,
sagemaker_session=sagemaker_session,
custom_file_filter=custom_file_filter,
)

def _prepare_and_upload_runtime_scripts(
16 changes: 15 additions & 1 deletion src/sagemaker/remote_function/client.py
@@ -17,7 +17,7 @@
from collections import deque
import time
import threading
-from typing import Dict, List, Tuple, Any
+from typing import Callable, Dict, List, Optional, Tuple, Any
import functools
import itertools
import inspect
@@ -85,6 +85,7 @@ def remote(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
custom_file_filter: Optional[Callable[[str, List], List]] = None,
):
"""Decorator for running the annotated function as a SageMaker training job.

@@ -265,6 +266,11 @@ def remote(
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
After this amount of time Amazon SageMaker will stop waiting for managed spot training
job to complete. Defaults to ``None``.

custom_file_filter (Callable[[str, List], List]): A function that filters job
dependencies to be uploaded to S3. This function is passed to the ``ignore``
argument of ``shutil.copytree``. Defaults to ``None``, which means only Python
files are accepted.
"""

def _remote(func):
@@ -296,6 +302,7 @@ def _remote(func):
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
custom_file_filter=custom_file_filter,
)

@functools.wraps(func)
@@ -506,6 +513,7 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
custom_file_filter: Optional[Callable[[str, List], List]] = None,
):
"""Constructor for RemoteExecutor

@@ -692,6 +700,11 @@ def __init__(
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
After this amount of time Amazon SageMaker will stop waiting for managed spot training
job to complete. Defaults to ``None``.

custom_file_filter (Callable[[str, List], List]): A function that filters job
dependencies to be uploaded to S3. This function is passed to the ``ignore``
argument of ``shutil.copytree``. Defaults to ``None``, which means only Python
files are accepted.
"""
self.max_parallel_jobs = max_parallel_jobs

@@ -731,6 +744,7 @@ def __init__(
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
custom_file_filter=custom_file_filter,
)

self._state_condition = threading.Condition()
15 changes: 12 additions & 3 deletions src/sagemaker/remote_function/job.py
@@ -20,7 +20,7 @@
import sys
import json
import secrets
-from typing import Dict, List, Tuple
+from typing import Callable, Dict, List, Optional, Tuple
from urllib.parse import urlparse
from io import BytesIO

@@ -193,6 +193,7 @@ def __init__(
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
custom_file_filter: Optional[Callable[[str, List], List]] = None,
):
"""Initialize a _JobSettings instance which configures the remote job.
@@ -363,6 +364,11 @@ def __init__(
max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
After this amount of time Amazon SageMaker will stop waiting for managed spot
training job to complete. Defaults to ``None``.
custom_file_filter (Callable[[str, List], List]): A function that filters job
dependencies to be uploaded to S3. This function is passed to the ``ignore``
argument of ``shutil.copytree``. Defaults to ``None``, which means only Python
files are accepted.
"""
self.sagemaker_session = sagemaker_session or Session()
self.environment_variables = resolve_value_from_config(
@@ -450,6 +456,7 @@ def __init__(
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
self.spark_config = spark_config
self.use_spot_instances = use_spot_instances
self.custom_file_filter = custom_file_filter
self.max_wait_time_in_seconds = max_wait_time_in_seconds
self.job_conda_env = resolve_value_from_config(
direct_input=job_conda_env,
@@ -649,6 +656,7 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None):
s3_base_uri=s3_base_uri,
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=job_settings.sagemaker_session,
custom_file_filter=job_settings.custom_file_filter,
)

stored_function = StoredFunction(
@@ -890,6 +898,7 @@ def _prepare_and_upload_dependencies(
s3_base_uri: str,
s3_kms_key: str,
sagemaker_session: Session,
custom_file_filter: Optional[Callable[[str, List], List]] = None,
) -> str:
"""Upload the job dependencies to S3 if present"""

@@ -906,12 +915,12 @@ def _prepare_and_upload_dependencies(
os.mkdir(tmp_workspace_dir)
# TODO Remove the following hack to avoid dir_exists error in the copy_tree call below.
tmp_workspace = os.path.join(tmp_workspace_dir, JOB_REMOTE_FUNCTION_WORKSPACE)

ignore = custom_file_filter if custom_file_filter is not None else _filter_non_python_files
if include_local_workdir:
shutil.copytree(
os.getcwd(),
tmp_workspace,
-ignore=_filter_non_python_files,
+ignore=ignore,
)
logger.info("Copied user workspace python scripts to '%s'", tmp_workspace)

@@ -50,6 +50,10 @@ def runtime_env_manager():
return mocked_runtime_env_manager


def custom_file_filter():
pass


@pytest.fixture
def remote_decorator_config(sagemaker_session):
return Mock(
@@ -64,6 +68,7 @@ def remote_decorator_config(sagemaker_session):
pre_execution_script="some_path",
python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
custom_file_filter=None,
)


@@ -72,6 +77,24 @@ def config_uploader(remote_decorator_config, runtime_env_manager):
return ConfigUploader(remote_decorator_config, runtime_env_manager)


@pytest.fixture
def remote_decorator_config_with_filter(sagemaker_session):
return Mock(
_JobSettings,
sagemaker_session=sagemaker_session,
s3_root_uri="some_s3_uri",
s3_kms_key="some_kms",
spark_config=SparkConfig(),
dependencies=None,
include_local_workdir=True,
pre_execution_commands="some_commands",
pre_execution_script="some_path",
python_sdk_whl_s3_uri=SAGEMAKER_WHL_FILE_S3_PATH,
environment_variables={"REMOTE_FUNCTION_SECRET_KEY": "some_secret_key"},
custom_file_filter=custom_file_filter,
)


@patch("sagemaker.feature_store.feature_processor._config_uploader.StoredFunction")
def test_prepare_and_upload_callable(mock_stored_function, config_uploader, wrapped_func):
mock_stored_function.save(wrapped_func).return_value = None
@@ -108,6 +131,42 @@ def test_prepare_and_upload_dependencies(mock_upload, config_uploader):
s3_base_uri=remote_decorator_config.s3_root_uri,
s3_kms_key=remote_decorator_config.s3_kms_key,
sagemaker_session=sagemaker_session,
custom_file_filter=None,
)


@patch(
"sagemaker.feature_store.feature_processor._config_uploader._prepare_and_upload_dependencies",
return_value="some_s3_uri",
)
def test_prepare_and_upload_dependencies_with_filter(
mock_job_upload, remote_decorator_config_with_filter, runtime_env_manager
):
config_uploader_with_filter = ConfigUploader(
remote_decorator_config=remote_decorator_config_with_filter,
runtime_env_manager=runtime_env_manager,
)
remote_decorator_config = config_uploader_with_filter.remote_decorator_config
config_uploader_with_filter._prepare_and_upload_dependencies(
local_dependencies_path="some/path/to/dependency",
include_local_workdir=True,
pre_execution_commands=remote_decorator_config.pre_execution_commands,
pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
s3_base_uri=remote_decorator_config.s3_root_uri,
s3_kms_key=remote_decorator_config.s3_kms_key,
sagemaker_session=sagemaker_session,
custom_file_filter=remote_decorator_config_with_filter.custom_file_filter,
)

mock_job_upload.assert_called_once_with(
local_dependencies_path="some/path/to/dependency",
include_local_workdir=True,
pre_execution_commands=remote_decorator_config.pre_execution_commands,
pre_execution_script_local_path=remote_decorator_config.pre_execution_script,
s3_base_uri=remote_decorator_config.s3_root_uri,
s3_kms_key=remote_decorator_config.s3_kms_key,
sagemaker_session=sagemaker_session,
custom_file_filter=custom_file_filter,
)


@@ -201,6 +260,7 @@ def test_prepare_step_input_channel(
s3_base_uri=remote_decorator_config.s3_root_uri,
s3_kms_key="some_kms",
sagemaker_session=sagemaker_session,
custom_file_filter=None,
)

mock_spark_dependency_upload.assert_called_once_with(
@@ -263,13 +263,7 @@ def test_to_pipeline(
)

mock_dependency_upload.assert_called_once_with(
-local_dependencies_path,
-True,
-None,
-None,
-f"{S3_URI}/pipeline_name",
-None,
-session,
+local_dependencies_path, True, None, None, f"{S3_URI}/pipeline_name", None, session, None
)

mock_spark_dependency_upload.assert_called_once_with(
@@ -875,6 +869,7 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
"tags",
"use_spot_instances",
"max_wait_time_in_seconds",
"custom_file_filter",
}

job_settings = _JobSettings(
26 changes: 26 additions & 0 deletions tests/unit/sagemaker/remote_function/test_client.py
@@ -175,6 +175,32 @@ def square(x):
assert mock_job_settings.call_args.kwargs["image_uri"] == IMAGE


@patch(
"sagemaker.remote_function.client.serialization.deserialize_obj_from_s3",
return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_with_custom_file_filter(
mock_start, mock_job_settings, mock_deserialize_obj_from_s3
):
mock_job = Mock(job_name=TRAINING_JOB_NAME)
mock_job.describe.return_value = COMPLETED_TRAINING_JOB

mock_start.return_value = mock_job

def custom_file_filter():
pass

@remote(image_uri=IMAGE, s3_root_uri=S3_URI, custom_file_filter=custom_file_filter)
def square(x):
return x * x

result = square(5)
assert result == EXPECTED_JOB_RESULT
assert mock_job_settings.call_args.kwargs["custom_file_filter"] == custom_file_filter


@patch(
"sagemaker.remote_function.client.serialization.deserialize_exception_from_s3",
return_value=ZeroDivisionError("division by zero"),
28 changes: 28 additions & 0 deletions tests/unit/sagemaker/remote_function/test_job.py
@@ -358,6 +358,7 @@ def test_start(
s3_base_uri=f"{S3_URI}/{job.job_name}",
s3_kms_key=None,
sagemaker_session=session(),
custom_file_filter=None,
)

session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -480,6 +481,7 @@ def test_start_with_complete_job_settings(
s3_base_uri=f"{S3_URI}/{job.job_name}",
s3_kms_key=job_settings.s3_kms_key,
sagemaker_session=session(),
custom_file_filter=None,
)

session().sagemaker_client.create_training_job.assert_called_once_with(
@@ -778,6 +780,32 @@ def test_prepare_and_upload_dependencies(session, mock_copytree, mock_copy, mock_s3_upload):
)


@patch("sagemaker.s3.S3Uploader.upload", return_value="some_uri")
@patch("shutil.copy2")
@patch("shutil.copytree")
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
def test_prepare_and_upload_dependencies_with_custom_filter(
session, mock_copytree, mock_copy, mock_s3_upload
):
def custom_file_filter():
pass

s3_path = _prepare_and_upload_dependencies(
local_dependencies_path="some/path/to/dependency",
include_local_workdir=True,
pre_execution_commands=["cmd_1", "cmd_2"],
pre_execution_script_local_path=None,
s3_base_uri=S3_URI,
s3_kms_key=KMS_KEY_ARN,
sagemaker_session=session,
custom_file_filter=custom_file_filter,
)

assert s3_path == mock_s3_upload.return_value

mock_copytree.assert_called_with(os.getcwd(), ANY, ignore=custom_file_filter)


@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
def test_prepare_and_upload_spark_dependent_file_without_spark_config(session):
assert _prepare_and_upload_spark_dependent_files(