Skip to content

Commit 48dcd78

Browse files
committed
fix(remote-function): use correct line endings and s3 uris on windows
Closes #4090
1 parent 9414236 commit 48dcd78

File tree

5 files changed

+23
-41
lines changed

5 files changed

+23
-41
lines changed

src/sagemaker/remote_function/core/serialization.py

Lines changed: 17 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import dataclasses
1717
import json
18-
import os
1918
import sys
2019
import hmac
2120
import hashlib
@@ -29,6 +28,8 @@
2928

3029
from tblib import pickling_support
3130

31+
# Note: do not use os.path.join to build S3 URIs; on Windows it joins with backslashes, which produces invalid URIs.
32+
3233

3334
def _get_python_version():
3435
return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
@@ -143,18 +144,15 @@ def serialize_func_to_s3(
143144
Raises:
144145
SerializationError: when it fails to serialize the function to bytes.
145146
"""
146-
147147
bytes_to_upload = CloudpickleSerializer.serialize(func)
148148

149-
_upload_bytes_to_s3(
150-
bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
151-
)
149+
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
152150

153151
sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
154152

155153
_upload_bytes_to_s3(
156154
_MetaData(sha256_hash).to_json(),
157-
os.path.join(s3_uri, "metadata.json"),
155+
f"{s3_uri}/metadata.json",
158156
s3_kms_key,
159157
sagemaker_session,
160158
)
@@ -177,20 +175,16 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
177175
DeserializationError: when it fails to deserialize the function from bytes.
178176
"""
179177
metadata = _MetaData.from_json(
180-
_read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
178+
_read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
181179
)
182180

183-
bytes_to_deserialize = _read_bytes_from_s3(
184-
os.path.join(s3_uri, "payload.pkl"), sagemaker_session
185-
)
181+
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
186182

187183
_perform_integrity_check(
188184
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
189185
)
190186

191-
return CloudpickleSerializer.deserialize(
192-
os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
193-
)
187+
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
194188

195189

196190
def serialize_obj_to_s3(
@@ -211,15 +205,13 @@ def serialize_obj_to_s3(
211205

212206
bytes_to_upload = CloudpickleSerializer.serialize(obj)
213207

214-
_upload_bytes_to_s3(
215-
bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
216-
)
208+
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
217209

218210
sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
219211

220212
_upload_bytes_to_s3(
221213
_MetaData(sha256_hash).to_json(),
222-
os.path.join(s3_uri, "metadata.json"),
214+
f"{s3_uri}/metadata.json",
223215
s3_kms_key,
224216
sagemaker_session,
225217
)
@@ -240,20 +232,16 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
240232
"""
241233

242234
metadata = _MetaData.from_json(
243-
_read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
235+
_read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
244236
)
245237

246-
bytes_to_deserialize = _read_bytes_from_s3(
247-
os.path.join(s3_uri, "payload.pkl"), sagemaker_session
248-
)
238+
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
249239

250240
_perform_integrity_check(
251241
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
252242
)
253243

254-
return CloudpickleSerializer.deserialize(
255-
os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
256-
)
244+
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
257245

258246

259247
def serialize_exception_to_s3(
@@ -275,15 +263,13 @@ def serialize_exception_to_s3(
275263

276264
bytes_to_upload = CloudpickleSerializer.serialize(exc)
277265

278-
_upload_bytes_to_s3(
279-
bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
280-
)
266+
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
281267

282268
sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
283269

284270
_upload_bytes_to_s3(
285271
_MetaData(sha256_hash).to_json(),
286-
os.path.join(s3_uri, "metadata.json"),
272+
f"{s3_uri}/metadata.json",
287273
s3_kms_key,
288274
sagemaker_session,
289275
)
@@ -304,20 +290,16 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
304290
"""
305291

306292
metadata = _MetaData.from_json(
307-
_read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
293+
_read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
308294
)
309295

310-
bytes_to_deserialize = _read_bytes_from_s3(
311-
os.path.join(s3_uri, "payload.pkl"), sagemaker_session
312-
)
296+
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
313297

314298
_perform_integrity_check(
315299
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
316300
)
317301

318-
return CloudpickleSerializer.deserialize(
319-
os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
320-
)
302+
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
321303

322304

323305
def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):

src/sagemaker/remote_function/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,7 @@ def _prepare_and_upload_runtime_scripts(
860860
)
861861
shutil.copy2(spark_script_path, bootstrap_scripts)
862862

863-
with open(entrypoint_script_path, "w") as file:
863+
with open(entrypoint_script_path, "w", newline="\n") as file:
864864
file.writelines(entry_point_script)
865865

866866
bootstrap_script_path = os.path.join(

src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _bootstrap_runtime_environment(
7474
Args:
7575
conda_env (str): conda environment to be activated. Default is None.
7676
"""
77-
workspace_archive_dir_path = os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE)
77+
workspace_archive_dir_path = f"{BASE_CHANNEL_PATH}/{REMOTE_FUNCTION_WORKSPACE}"
7878

7979
if not os.path.exists(workspace_archive_dir_path):
8080
logger.info(
@@ -84,7 +84,7 @@ def _bootstrap_runtime_environment(
8484
return
8585

8686
# Unpack user workspace archive first.
87-
workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip")
87+
workspace_archive_path = f"{workspace_archive_dir_path}/workspace.zip"
8888
if not os.path.isfile(workspace_archive_path):
8989
logger.info(
9090
"Workspace archive '%s' does not exist. Assuming no dependencies to bootstrap.",

tests/unit/sagemaker/remote_function/core/test_serialization.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
import os.path
1615
import random
1716
import string
1817
import pytest
@@ -186,7 +185,7 @@ def square(x):
186185
serialize_func_to_s3(
187186
func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
188187
)
189-
mock_s3[os.path.join(s3_uri, "metadata.json")] = b"not json serializable"
188+
mock_s3[f"{s3_uri}/metadata.json"] = b"not json serializable"
190189

191190
del square
192191

tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ def test_main_no_dependency_file(
199199
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
200200
path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH)
201201
file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH)
202-
get_cwd.assert_called_once()
202+
# Called twice by pathlib on some platforms
203+
get_cwd.assert_called()
203204
list_dir.assert_called_once_with(pathlib.Path(TEST_DEPENDENCIES_PATH))
204205
run_pre_exec_script.assert_called()
205206
bootstrap_runtime.assert_not_called()

0 commit comments

Comments
 (0)