
fix: use correct line endings and s3 uris on windows #4118


Merged
52 changes: 17 additions & 35 deletions src/sagemaker/remote_function/core/serialization.py
@@ -15,7 +15,6 @@

 import dataclasses
 import json
-import os
 import sys
 import hmac
 import hashlib
@@ -29,6 +28,8 @@

 from tblib import pickling_support

+# Note: do not use os.path.join for s3 uris, fails on windows
Member
Wanted to understand more about this shortcoming; has this been researched already?
We may need to add this to the MergeChecklist so contributors are aware.

Contributor Author

Yes, Windows users run into this with S3-related tooling every so often. os.sep on Windows is `\`, which doesn't play nice with S3. It's not the first time we've run into it; other tools make the same misstep.

While testing the line-ending changes, we had to change how the S3 URIs are constructed so the correct objects get created on S3 from Windows. Python handles `/` paths on Windows filesystems just fine, but S3 creates an object with `\` in the object name instead of treating it as a prefix.
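For illustration, here is a minimal sketch of the pitfall (the bucket and prefix are hypothetical, and ntpath stands in for os.path on Windows so the snippet runs on any platform):

```python
import ntpath      # the Windows flavor of os.path, usable on any platform
import posixpath   # always joins with "/", safe for S3 keys

s3_uri = "s3://my-bucket/prefix"  # hypothetical bucket and prefix

# os.path.join on Windows inserts the platform separator "\":
print(ntpath.join(s3_uri, "payload.pkl"))     # s3://my-bucket/prefix\payload.pkl

# Alternatives that always emit "/":
print(posixpath.join(s3_uri, "payload.pkl"))  # s3://my-bucket/prefix/payload.pkl
print(f"{s3_uri}/payload.pkl")                # the approach this PR takes
```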

[screenshot: S3 console showing an object created with `\` in its name instead of a prefix]

The tests were also failing on Windows, since the assertions expect `/` but were getting `\`.

Before change
pytest ./tests/unit/sagemaker/remote_function/
...

E           AssertionError: expected call not found.
E           Expected: exists('/opt/ml/input/data/sm_rf_user_ws')
E           Actual: exists('/opt/ml/input/data\\sm_rf_user_ws')

venv\Lib\site-packages\mock\mock.py:913: AssertionError
------------------------------------------------------------------- Captured stderr call -------------------------------------------------------------------
2023-09-13 18:07:57,722 sagemaker.remote_function INFO     Successfully unpacked workspace archive at '\user\set\workdir'.
2023-09-13 18:07:57,722 sagemaker.remote_function INFO     Did not find any dependency file in workspace directory at '/opt/ml/input/data\sm_rf_user_ws'. Assuming no additional dependencies to install.
================================================================= short test summary info ==================================================================
FAILED tests/unit/sagemaker/remote_function/core/test_serialization.py::test_deserialize_func_deserialization_error - AssertionError: Regex pattern "Error...
FAILED tests/unit/sagemaker/remote_function/core/test_serialization.py::test_deserialize_obj_deserialization_error - AssertionError: Regex pattern "Error ...
FAILED tests/unit/sagemaker/remote_function/core/test_serialization.py::test_serialize_deserialize_service_error - AssertionError: Regex pattern "Failed t...
FAILED tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py::test_main_success - AssertionError: expected call n...
FAILED tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py::test_main_channel_folder_does_not_exist - Assertion...
FAILED tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py::test_main_no_workspace_archive - AssertionError: ex...
FAILED tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py::test_main_no_dependency_file - AssertionError: expe...

I can add that if you like. Or should there be a separate issue/PR to search the SDK for this pattern and correct the other occurrences?

Member
I'll take a note of this for our team and drive the required changes. Thanks for the explanation.



 def _get_python_version():
     return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
@@ -143,18 +144,15 @@ def serialize_func_to_s3(
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
-
     bytes_to_upload = CloudpickleSerializer.serialize(func)

-    _upload_bytes_to_s3(
-        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
-    )
+    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

     _upload_bytes_to_s3(
         _MetaData(sha256_hash).to_json(),
-        os.path.join(s3_uri, "metadata.json"),
+        f"{s3_uri}/metadata.json",
         s3_kms_key,
         sagemaker_session,
     )
@@ -177,20 +175,16 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
         DeserializationError: when fail to serialize function to bytes.
     """
     metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
     )

-    bytes_to_deserialize = _read_bytes_from_s3(
-        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
-    )
+    bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

     _perform_integrity_check(
         expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
     )

-    return CloudpickleSerializer.deserialize(
-        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
-    )
+    return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)


 def serialize_obj_to_s3(
@@ -211,15 +205,13 @@

     bytes_to_upload = CloudpickleSerializer.serialize(obj)

-    _upload_bytes_to_s3(
-        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
-    )
+    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

     _upload_bytes_to_s3(
         _MetaData(sha256_hash).to_json(),
-        os.path.join(s3_uri, "metadata.json"),
+        f"{s3_uri}/metadata.json",
         s3_kms_key,
         sagemaker_session,
     )
@@ -240,20 +232,16 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
     """

     metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
     )

-    bytes_to_deserialize = _read_bytes_from_s3(
-        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
-    )
+    bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

     _perform_integrity_check(
         expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
     )

-    return CloudpickleSerializer.deserialize(
-        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
-    )
+    return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)


 def serialize_exception_to_s3(
@@ -275,15 +263,13 @@

     bytes_to_upload = CloudpickleSerializer.serialize(exc)

-    _upload_bytes_to_s3(
-        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
-    )
+    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

     _upload_bytes_to_s3(
         _MetaData(sha256_hash).to_json(),
-        os.path.join(s3_uri, "metadata.json"),
+        f"{s3_uri}/metadata.json",
         s3_kms_key,
         sagemaker_session,
     )
@@ -304,20 +290,16 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
     """

     metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
     )

-    bytes_to_deserialize = _read_bytes_from_s3(
-        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
-    )
+    bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)

     _perform_integrity_check(
         expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
     )

-    return CloudpickleSerializer.deserialize(
-        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
-    )
+    return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)


 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
2 changes: 1 addition & 1 deletion src/sagemaker/remote_function/job.py
@@ -860,7 +860,7 @@ def _prepare_and_upload_runtime_scripts(
     )
     shutil.copy2(spark_script_path, bootstrap_scripts)

-    with open(entrypoint_script_path, "w") as file:
+    with open(entrypoint_script_path, "w", newline="\n") as file:
Member
Is this change backwards compatible?

Contributor Author (@jmahlik), Sep 13, 2023

Yes, it ensures the line endings on job_driver.sh get written out with `\n` as newlines on all platforms instead of `\r\n` on Windows. They need to be `\n`, since the script is executed in a Linux container.

Currently it's not an issue when submitting remote functions from macOS or Linux. But when submitting from Windows, the script gets written out with `\r\n` and then uploaded to S3, which causes `/opt/ml/input/data/sagemaker_remote_function_bootstrap/job_driver.sh: line 3: $'\r': command not found` when the script is executed on the training image (which we can safely assume will not be a Windows container).

Side note: Windows users run into line-ending issues with .sh scripts quite often. Another workaround is running something like dos2unix on the scripts before executing them on Linux, but that's not possible in this case since the SDK is the one writing them out.
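A minimal sketch of what newline="\n" changes (the file names here are hypothetical):

```python
script = "#!/bin/bash\necho hello\n"

# Default text mode: every "\n" written is translated to os.linesep,
# which is "\r\n" on Windows; bash then chokes on the stray "\r".
with open("job_driver_crlf.sh", "w") as f:
    f.write(script)

# newline="\n" disables that translation, so the script is written with
# LF endings on every platform and runs cleanly in a Linux container.
with open("job_driver_lf.sh", "w", newline="\n") as f:
    f.write(script)
```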

         file.writelines(entry_point_script)

     bootstrap_script_path = os.path.join(
src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

@@ -74,7 +74,7 @@ def _bootstrap_runtime_environment(
     Args:
         conda_env (str): conda environment to be activated. Default is None.
     """
-    workspace_archive_dir_path = os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE)
+    workspace_archive_dir_path = f"{BASE_CHANNEL_PATH}/{REMOTE_FUNCTION_WORKSPACE}"

     if not os.path.exists(workspace_archive_dir_path):
         logger.info(
@@ -84,7 +84,7 @@
         return

     # Unpack user workspace archive first.
-    workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip")
+    workspace_archive_path = f"{workspace_archive_dir_path}/workspace.zip"
     if not os.path.isfile(workspace_archive_path):
         logger.info(
             "Workspace archive '%s' does not exist. Assuming no dependencies to bootstrap.",
tests/unit/sagemaker/remote_function/core/test_serialization.py

@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import

-import os.path
 import random
 import string
 import pytest
@@ -186,7 +185,7 @@ def square(x):
     serialize_func_to_s3(
         func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
     )
-    mock_s3[os.path.join(s3_uri, "metadata.json")] = b"not json serializable"
+    mock_s3[f"{s3_uri}/metadata.json"] = b"not json serializable"

     del square
tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py

@@ -199,7 +199,8 @@ def test_main_no_dependency_file(
     validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
     path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH)
     file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH)
-    get_cwd.assert_called_once()
+    # Called twice by pathlib on some platforms
+    get_cwd.assert_called()
     list_dir.assert_called_once_with(pathlib.Path(TEST_DEPENDENCIES_PATH))
     run_pre_exec_script.assert_called()
     bootstrap_runtime.assert_not_called()