Skip to content

Commit 58d3f46

Browse files
author
Namrata Madan
committed
feature: support source directory on job side
1 parent f67918a commit 58d3f46

File tree

7 files changed

+113
-23
lines changed

7 files changed

+113
-23
lines changed

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616
import os
1717
import pathlib
1818
import shutil
19+
import sys
1920

2021
from sagemaker.utils import _tmpdir
21-
from sagemaker.s3 import s3_path_join, S3Uploader
22+
from sagemaker.s3 import s3_path_join, S3Uploader, S3Downloader
2223
from sagemaker.remote_function import logging_config
2324

2425
import sagemaker.remote_function.core.serialization as serialization
@@ -80,17 +81,18 @@ def save(self, func, source_dir=None, *args, **kwargs):
8081
def _zip_and_upload_source_dir(self, source_dir):
    """Archive ``source_dir`` as a zip file and upload it to S3.

    The archive is created in a temporary directory, named after the source
    directory, and uploaded to ``<s3_base_uri>/source_dir``.

    Args:
        source_dir: Path of the local directory to upload. Anything accepted
            by ``pathlib.Path`` (``str`` or ``os.PathLike``) works.

    Raises:
        AttributeError: If ``source_dir`` is not an existing directory.
    """
    source_dir_path = pathlib.Path(source_dir)
    if not source_dir_path.is_dir():
        # Format with an f-string rather than ``+`` so a path-like argument
        # produces the intended AttributeError instead of a TypeError.
        raise AttributeError(f"{source_dir} is not a valid directory.")

    s3_path = s3_path_join(self.s3_base_uri, "source_dir")
    logger.info(f"Uploading function source directory to {s3_path}")

    with _tmpdir() as tmp:
        # Archive relative to the parent so the zip contains the directory
        # itself (by name) as its top-level entry.
        archived_filepath = shutil.make_archive(
            os.path.join(tmp, source_dir_path.name),
            "zip",
            source_dir_path.parent,
            source_dir_path.name,
        )
        # Upload while the temporary directory still exists.
        S3Uploader.upload(archived_filepath, s3_path, self.s3_kms_key, self.sagemaker_session)
9496

9597
def load_and_invoke(self) -> None:
9698
"""Load and deserialize the function and the arguments and then execute it."""
@@ -109,6 +111,8 @@ def load_and_invoke(self) -> None:
109111
self.sagemaker_session, s3_path_join(self.s3_base_uri, "arguments.pkl")
110112
)
111113

114+
self._download_and_unzip_source_dir()
115+
112116
logger.info("Invoking the function")
113117
result = func(*args, **kwargs)
114118

@@ -121,3 +125,27 @@ def load_and_invoke(self) -> None:
121125
s3_path_join(self.s3_base_uri, "results.pkl"),
122126
self.s3_kms_key,
123127
)
128+
129+
def _download_and_unzip_source_dir(self):
    """Download the archived source directory from S3 and add it to ``sys.path``.

    Files under ``<s3_base_uri>/source_dir`` are downloaded into
    ``<cwd>/source_dir``. The first downloaded file is unpacked next to
    itself, and ``<cwd>/source_dir`` is then appended to ``sys.path``.
    Returns without doing anything else when nothing was downloaded.
    """
    remote_uri = s3_path_join(self.s3_base_uri, "source_dir")
    local_dir = os.path.join(os.getcwd(), "source_dir")

    logger.info(f"Downloading source modules from {remote_uri} to {local_dir}")

    fetched = S3Downloader.download(
        remote_uri,
        local_dir,
        kms_key=self.s3_kms_key,
        sagemaker_session=self.sagemaker_session,
    )

    # Nothing was downloaded, so there is no archive to unpack.
    if len(fetched) < 1:
        return

    archive = fetched[0]
    # Extract into the directory that holds the downloaded archive.
    extract_to = pathlib.Path(archive).parent.absolute()
    shutil.unpack_archive(archive, extract_to)
    sys.path.append(local_dir)

src/sagemaker/s3.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ def download(s3_uri, local_path, kms_key=None, sagemaker_session=None):
171171
manages interactions with Amazon SageMaker APIs and any other
172172
AWS services needed. If not specified, one is created
173173
using the default AWS configuration chain.
174+
175+
Returns:
176+
list[str]: List of local paths of downloaded files
174177
"""
175178
sagemaker_session = sagemaker_session or Session()
176179
bucket, key_prefix = parse_s3_url(url=s3_uri)
@@ -179,7 +182,7 @@ def download(s3_uri, local_path, kms_key=None, sagemaker_session=None):
179182
else:
180183
extra_args = None
181184

182-
sagemaker_session.download_data(
185+
return sagemaker_session.download_data(
183186
path=local_path, bucket=bucket, key_prefix=key_prefix, extra_args=extra_args
184187
)
185188

src/sagemaker/session.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,9 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
389389
download operation. Please refer to the ExtraArgs parameter in the boto3
390390
documentation here:
391391
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-download-file.html
392+
393+
Returns:
394+
list[str]: List of local paths of downloaded files
392395
"""
393396
# Initialize the S3 client.
394397
if self.s3_client is None:
@@ -408,7 +411,12 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
408411
if next_token != "":
409412
request_parameters.update({"ContinuationToken": next_token})
410413
response = s3.list_objects_v2(**request_parameters)
411-
contents = response.get("Contents")
414+
contents = response.get("Contents", None)
415+
if not contents:
416+
LOGGER.info(
417+
"Nothing to download from bucket: %s, key_prefix: %s.", bucket, key_prefix
418+
)
419+
return []
412420
# For each object, save its key or directory.
413421
for s3_object in contents:
414422
key = s3_object.get("Key")
@@ -417,6 +425,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
417425

418426
# For each object key, create the directory on the local machine if needed, and then
419427
# download the file.
428+
downloaded_paths = []
420429
for key in keys:
421430
tail_s3_uri_path = os.path.basename(key)
422431
if not os.path.splitext(key_prefix)[1]:
@@ -427,6 +436,8 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
427436
s3.download_file(
428437
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
429438
)
439+
downloaded_paths.append(destination_path)
440+
return downloaded_paths
430441

431442
def read_s3_file(self, bucket, key_prefix):
432443
"""Read a single file from S3.

tests/integ/sagemaker/remote_function/helpers/__init__.py

Whitespace-only changes.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from __future__ import absolute_import
2+
3+
4+
def square(x):
    """Return ``x`` multiplied by itself."""
    squared = x * x
    return squared

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,32 @@ def divide(x, y):
9898
assert divide(10, 2) == 5
9999

100100

101+
def test_with_local_dependencies(
    sagemaker_session, dummy_container_without_error, cpu_instance_type
):
    """Check that a ``@remote`` function can use a module from ``source_dir``."""
    requirements_file = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
    helpers_dir = os.path.join(os.path.dirname(__file__), "helpers")

    @remote(
        role=ROLE,
        image_uri=dummy_container_without_error,
        instance_type=cpu_instance_type,
        dependencies=requirements_file,
        source_dir=helpers_dir,
        sagemaker_session=sagemaker_session,
    )
    def square(x):
        # Imported in the function body so the import runs when the function
        # is invoked, not when this test module is loaded.
        from helpers import local_module

        return local_module.square(x)

    result = square(9)
    assert result == 81
121+
122+
101123
def test_with_additional_dependencies(
102124
sagemaker_session, dummy_container_without_error, cpu_instance_type
103125
):
104-
dependencies_path = os.path.join(DATA_DIR, "remote_function/requirements.txt")
126+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
105127

106128
@remote(
107129
role=ROLE,
@@ -121,7 +143,7 @@ def cuberoot(x):
121143
def test_additional_dependencies_with_job_conda_env(
122144
sagemaker_session, dummy_container_with_conda, cpu_instance_type
123145
):
124-
dependencies_path = os.path.join(DATA_DIR, "remote_function/requirements.txt")
146+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
125147

126148
@remote(
127149
role=ROLE,
@@ -142,7 +164,7 @@ def cuberoot(x):
142164
def test_additional_dependencies_with_default_conda_env(
143165
sagemaker_session, dummy_container_with_conda, cpu_instance_type
144166
):
145-
dependencies_path = os.path.join(DATA_DIR, "remote_function/requirements.txt")
167+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
146168

147169
@remote(
148170
role=ROLE,
@@ -162,7 +184,7 @@ def cuberoot(x):
162184
def test_additional_dependencies_with_non_existent_conda_env(
163185
sagemaker_session, dummy_container_with_conda, cpu_instance_type
164186
):
165-
dependencies_path = os.path.join(DATA_DIR, "remote_function/requirements.txt")
187+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "requirements.txt")
166188

167189
@remote(
168190
role=ROLE,
@@ -184,7 +206,7 @@ def cuberoot(x):
184206
def test_with_non_existent_dependencies(
185207
sagemaker_session, dummy_container_without_error, cpu_instance_type
186208
):
187-
dependencies_path = os.path.join(DATA_DIR, "remote_function/non_existent_requirements.txt")
209+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "non_existent_requirements.txt")
188210

189211
@remote(
190212
role=ROLE,
@@ -204,7 +226,7 @@ def test_with_incompatible_dependencies(
204226
sagemaker_session, dummy_container_without_error, cpu_instance_type
205227
):
206228

207-
dependencies_path = os.path.join(DATA_DIR, "remote_function/old_deps_requirements.txt")
229+
dependencies_path = os.path.join(DATA_DIR, "remote_function", "old_deps_requirements.txt")
208230

209231
@remote(
210232
role=ROLE,

tests/unit/sagemaker/remote_function/core/test_stored_function.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
1516
import tempfile
1617

1718
import pytest
@@ -31,12 +32,12 @@ def random_s3_uri():
3132
return "".join(random.choices(string.ascii_uppercase + string.digits, k=10))
3233

3334

34-
def upload(b, s3_uri, kms_key=None, sagemaker_session=None):
35+
def upload_bytes(b, s3_uri, kms_key=None, sagemaker_session=None):
    """Stand-in for ``S3Uploader.upload_bytes`` that records ``b`` in ``mock_s3``.

    Asserts that the caller passed the expected KMS key, then stores the
    payload under ``s3_uri``. ``sagemaker_session`` is accepted but unused.
    """
    expected_key = KMS_KEY
    assert kms_key == expected_key
    mock_s3[s3_uri] = b
3738

3839

39-
def read(s3_uri, sagemaker_session=None):
40+
def read_bytes(s3_uri, sagemaker_session=None):
    """Stand-in for ``S3Downloader.read_bytes`` that reads from ``mock_s3``.

    Returns whatever is stored under ``s3_uri``. ``sagemaker_session`` is
    accepted but unused.
    """
    stored = mock_s3[s3_uri]
    return stored
4142

4243

@@ -48,10 +49,11 @@ def quadratic(x=2, *, a=1, b=0, c=0):
4849
"args, kwargs",
4950
[([], {}), ([3], {}), ([], {"a": 2, "b": 1, "c": 1})],
5051
)
51-
@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload)
52-
@patch("sagemaker.s3.S3Downloader.read_bytes", new=read)
52+
@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_bytes)
53+
@patch("sagemaker.s3.S3Downloader.read_bytes", new=read_bytes)
5354
@patch("sagemaker.s3.S3Uploader.upload")
54-
def test_save_and_load(mock_s3_upload, args, kwargs):
55+
@patch("sagemaker.s3.S3Downloader.download")
56+
def test_save_and_load(s3_source_dir_download, s3_source_dir_upload, args, kwargs):
5557
session = Mock()
5658
s3_base_uri = random_s3_uri()
5759
source_dir = tempfile.TemporaryDirectory()
@@ -61,12 +63,31 @@ def test_save_and_load(mock_s3_upload, args, kwargs):
6163
)
6264
stored_function.save(quadratic, source_dir.name, *args, **kwargs)
6365

64-
mock_s3_upload.assert_called_once_with(ANY, f"{s3_base_uri}/source_dir", KMS_KEY, session)
66+
s3_source_dir_upload.assert_called_once_with(ANY, f"{s3_base_uri}/source_dir", KMS_KEY, session)
6567

6668
stored_function.load_and_invoke()
6769

70+
s3_source_dir_download.assert_called_once_with(
71+
f"{s3_base_uri}/source_dir",
72+
os.path.join(os.getcwd(), "source_dir"),
73+
kms_key=KMS_KEY,
74+
sagemaker_session=session,
75+
)
76+
6877
assert deserialize_obj_from_s3(session, s3_uri=f"{s3_base_uri}/results.pkl") == quadratic(
6978
*args, **kwargs
7079
)
7180

7281
source_dir.cleanup()
82+
83+
84+
def test_save_invalid_source_directory():
    """Expect ``save`` to raise ``AttributeError`` for ``"invalid_source_dir"``."""
    stored_function = StoredFunction(
        sagemaker_session=Mock(), s3_base_uri=random_s3_uri(), s3_kms_key=KMS_KEY
    )

    with pytest.raises(AttributeError, match="not a valid directory."):
        stored_function.save(quadratic, "invalid_source_dir")

0 commit comments

Comments
 (0)