Skip to content

Commit d957c8a

Browse files
ZhankuilNamrata Madan
authored andcommitted
pathway: add logging
1 parent 9545c32 commit d957c8a

File tree

10 files changed

+112
-26
lines changed

10 files changed

+112
-26
lines changed

src/sagemaker/remote_function/client.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing import Dict, List, Tuple, Any
2121
import functools
2222
import inspect
23-
import logging
2423

2524
from botocore.exceptions import ClientError
2625

@@ -29,6 +28,7 @@
2928
from sagemaker.session import Session
3029
from sagemaker.s3 import s3_path_join
3130
from sagemaker.remote_function.job import _JobSettings, _Job
31+
from sagemaker.remote_function import logging_config
3232

3333

3434
_API_CALL_LIMIT = {
@@ -44,8 +44,9 @@
4444
_CANCELLED = "CANCELLED"
4545
_FINISHED = "FINISHED"
4646

47-
LOGGER = logging.getLogger(__name__)
48-
LOGGER.setLevel(logging.INFO)
47+
48+
logging_config.basic_config()
49+
logger = logging_config.get_logger()
4950

5051

5152
def remote(
@@ -191,7 +192,7 @@ def _submit_worker(executor):
191192
else:
192193
executor._running_jobs[job.job_name] = job
193194
except Exception: # pylint: disable=broad-except
194-
LOGGER.exception("Error occurred while submitting CreateTrainingJob requests.")
195+
logger.exception("Error occurred while submitting CreateTrainingJob requests.")
195196

196197

197198
def _polling_worker(executor):
@@ -227,13 +228,13 @@ def _polling_worker(executor):
227228
!= "LimitExceededException"
228229
):
229230
# Couldn't check the job status, move on
230-
LOGGER.exception(
231+
logger.exception(
231232
"Error occurred while checking the status of job %s", job_name
232233
)
233234
del executor._running_jobs[job_name]
234235
executor._semaphore.release()
235236
except Exception: # pylint: disable=broad-except
236-
LOGGER.exception("Error occurred while monitoring the job statuses.")
237+
logger.exception("Error occurred while monitoring the job statuses.")
237238

238239

239240
# TODO: 1) add map method.

src/sagemaker/remote_function/core/runtime_environment.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,16 @@
1414

1515
from __future__ import absolute_import
1616

17-
import logging
1817
import sys
1918
import os
2019
import shlex
2120
import subprocess
2221
from sagemaker.s3 import s3_path_join, S3Uploader, S3Downloader
2322
from sagemaker.session import Session
2423
from sagemaker.remote_function.errors import RuntimeEnvironmentError
24+
from sagemaker.remote_function.logging_config import get_logger
2525

26-
logging.basicConfig(level=logging.INFO)
27-
logger = logging.getLogger(__name__)
26+
logger = get_logger()
2827

2928

3029
class RuntimeEnvironmentManager:
@@ -182,10 +181,10 @@ def _log_error(process):
182181
for line in iter(pipe.readline, b""):
183182
error_str = str(line, "UTF-8")
184183
if "ERROR:" in error_str:
185-
logging.error(error_str)
184+
logger.error(error_str)
186185
error_logs = error_logs + error_str
187186
else:
188-
logging.warn(error_str)
187+
logger.warn(error_str)
189188

190189
return error_logs
191190

src/sagemaker/remote_function/core/stored_function.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,17 @@
1616
import os
1717
import pathlib
1818
import shutil
19+
1920
from sagemaker.utils import _tmpdir
2021
from sagemaker.s3 import s3_path_join, S3Uploader
22+
from sagemaker.remote_function import logging_config
2123

2224
import sagemaker.remote_function.core.serialization as serialization
2325

2426

27+
logger = logging_config.get_logger()
28+
29+
2530
class StoredFunction:
2631
"""Class representing a remote function stored in S3."""
2732

@@ -81,15 +86,27 @@ def _zip_and_upload_source_dir(self, source_dir):
8186

8287
def load_and_invoke(self) -> None:
8388
"""Load and deserialize the function and the arguments and then execute it."""
89+
90+
logger.info(
91+
f"Deserializing function code from {s3_path_join(self.s3_base_uri, 'function.pkl')}"
92+
)
8493
func = serialization.deserialize_func_from_s3(
8594
self.sagemaker_session, s3_path_join(self.s3_base_uri, "function.pkl")
8695
)
96+
97+
logger.info(
98+
f"Deserializing function arguments from {s3_path_join(self.s3_base_uri, 'arguments.pkl')}"
99+
)
87100
args, kwargs = serialization.deserialize_obj_from_s3(
88101
self.sagemaker_session, s3_path_join(self.s3_base_uri, "arguments.pkl")
89102
)
90103

104+
logger.info("Invoking the function")
91105
result = func(*args, **kwargs)
92106

107+
logger.info(
108+
f"Serializing the function return and uploading to {s3_path_join(self.s3_base_uri, 'results.pkl')}"
109+
)
93110
serialization.serialize_obj_to_s3(
94111
result,
95112
self.sagemaker_session,

src/sagemaker/remote_function/job.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
JOBS_CONTAINER_ENTRYPOINT = ["invoke-remote-function"]
2828

2929

30-
# TODO: 1) provide reasonable default values, e.g. image_uri
31-
# 2) extend this class to load job settings from the configuration files.
30+
# TODO: extend this class to load job settings from the configuration files.
3231
class _JobSettings:
3332
"""Helper class that processes the job settings.
3433
@@ -62,7 +61,7 @@ def __init__(
6261

6362
self.environment_variables = environment_variables
6463

65-
self.image_uri = image_uri or "fake_image_uri"
64+
self.image_uri = image_uri or _JobSettings._get_default_image_uri()
6665
self.dependencies = dependencies
6766

6867
self.instance_type = instance_type
@@ -92,6 +91,13 @@ def __init__(
9291

9392
self.tags = [] if tags is None else [{"Key": k, "Value": v} for k, v in tags]
9493

94+
@staticmethod
95+
def _get_default_image_uri():
96+
"""Get the default image uri"""
97+
98+
# TODO: provide default image uri if not set
99+
raise ValueError("image_uri must be set")
100+
95101

96102
class _Job:
97103
"""Helper class that interacts with the SageMaker training service."""

src/sagemaker/remote_function/job_driver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515

1616
import argparse
1717
import sys
18+
1819
import boto3
1920

2021
from sagemaker.session import Session
2122
from sagemaker.remote_function.errors import handle_error
2223
from sagemaker.remote_function.core.runtime_environment import RuntimeEnvironmentManager
24+
from sagemaker.remote_function import logging_config
2325

2426

2527
SUCCESS_EXIT_CODE = 0
@@ -79,6 +81,10 @@ def _execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key):
7981

8082
def main():
8183
"""Entry point for job driver script"""
84+
85+
logging_config.basic_config()
86+
logger = logging_config.get_logger()
87+
8288
exit_code = SUCCESS_EXIT_CODE
8389
try:
8490
args = _parse_agrs()
@@ -105,6 +111,7 @@ def main():
105111
_execute_remote_function(sagemaker_session, s3_base_uri, s3_kms_key)
106112

107113
except Exception as e: # pylint: disable=broad-except
114+
logger.exception("Error encountered when invoking the remote function.")
108115
exit_code = handle_error(e, sagemaker_session, s3_base_uri, s3_kms_key)
109116
finally:
110117
sys.exit(exit_code)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Utilities related to logging."""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
18+
19+
def get_logger():
20+
"""Return a logger with the name 'sagemaker'"""
21+
logger = logging.getLogger("sagemaker")
22+
logger.setLevel(logging.INFO)
23+
return logger
24+
25+
26+
def basic_config():
27+
"""Set logger configuration."""
28+
logging.basicConfig(format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s")

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def divide(x, y):
8080
divide(10, 2)
8181

8282

83+
# TODO: add VPC settings, update SageMakerRole with KMS permissions
8384
@pytest.mark.skip
8485
def test_advanced_job_setting(
8586
sagemaker_session, dummy_container_without_error, cpu_instance_type, s3_kms_key
@@ -88,7 +89,6 @@ def test_advanced_job_setting(
8889
role=ROLE,
8990
image_uri=dummy_container_without_error,
9091
instance_type=cpu_instance_type,
91-
# TODO: add VPC settings
9292
s3_kms_key=s3_kms_key,
9393
sagemaker_session=sagemaker_session,
9494
)

tests/unit/sagemaker/remote_function/core/test_serialization.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,15 @@ def test_serialize_deserialize_data():
9292
deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
9393

9494
assert deserialized == [10]
95+
96+
97+
@patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload)
98+
@patch("sagemaker.s3.S3Downloader.read_bytes", new=read)
99+
def test_serialize_deserialize_none():
100+
101+
s3_uri = random_s3_uri()
102+
serialize_obj_to_s3(None, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY)
103+
104+
deserialized = deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri)
105+
106+
assert deserialized is None

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -171,16 +171,18 @@ def test_executor_invalid_arguments():
171171
e.submit(job_function, 1, 2, c=3, d=4)
172172

173173

174-
def test_executor_submit_after_shutdown():
174+
@patch("sagemaker.remote_function.client._JobSettings")
175+
def test_executor_submit_after_shutdown(*args):
175176
with pytest.raises(RuntimeError):
176177
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as e:
177178
pass
178179
e.submit(job_function, 1, 2, c=3, d=4)
179180

180181

181182
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
183+
@patch("sagemaker.remote_function.client._JobSettings")
182184
@patch("sagemaker.remote_function.client._Job.start")
183-
def test_executor_submit_happy_case(mock_start):
185+
def test_executor_submit_happy_case(mock_start, *args):
184186
mock_job = Mock()
185187
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
186188
mock_job.job_name = TRAINING_JOB_NAME
@@ -202,8 +204,9 @@ def test_executor_submit_happy_case(mock_start):
202204

203205

204206
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
207+
@patch("sagemaker.remote_function.client._JobSettings")
205208
@patch("sagemaker.remote_function.client._Job.start")
206-
def test_executor_submit_enforcing_max_parallel_jobs(mock_start):
209+
def test_executor_submit_enforcing_max_parallel_jobs(mock_start, *args):
207210
mock_job = Mock()
208211
mock_job.describe.return_value = INPROGRESS_TRAINING_JOB
209212
mock_start.return_value = mock_job
@@ -227,8 +230,9 @@ def test_executor_submit_enforcing_max_parallel_jobs(mock_start):
227230

228231

229232
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
233+
@patch("sagemaker.remote_function.client._JobSettings")
230234
@patch("sagemaker.remote_function.client._Job.start")
231-
def test_executor_fails_to_start_job(mock_start):
235+
def test_executor_fails_to_start_job(mock_start, *args):
232236
mock_job = Mock()
233237
mock_job.describe.return_value = COMPLETED_TRAINING_JOB
234238

@@ -245,8 +249,9 @@ def test_executor_fails_to_start_job(mock_start):
245249

246250

247251
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
252+
@patch("sagemaker.remote_function.client._JobSettings")
248253
@patch("sagemaker.remote_function.client._Job.start")
249-
def test_executor_submit_and_cancel(mock_start):
254+
def test_executor_submit_and_cancel(mock_start, *args):
250255
mock_job = Mock()
251256
mock_job.describe.return_value = INPROGRESS_TRAINING_JOB
252257
mock_start.return_value = mock_job
@@ -270,8 +275,9 @@ def test_executor_submit_and_cancel(mock_start):
270275

271276

272277
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
278+
@patch("sagemaker.remote_function.client._JobSettings")
273279
@patch("sagemaker.remote_function.client._Job.start")
274-
def test_executor_describe_job_throttled_temporarily(mock_start):
280+
def test_executor_describe_job_throttled_temporarily(mock_start, *args):
275281
throttling_error = ClientError(
276282
error_response={"Error": {"Code": "LimitExceededException"}},
277283
operation_name="SomeOperation",
@@ -298,8 +304,9 @@ def test_executor_describe_job_throttled_temporarily(mock_start):
298304

299305

300306
@patch("sagemaker.remote_function.client._API_CALL_LIMIT", new=API_CALL_LIMIT)
307+
@patch("sagemaker.remote_function.client._JobSettings")
301308
@patch("sagemaker.remote_function.client._Job.start")
302-
def test_executor_describe_job_failed_permanently(mock_start):
309+
def test_executor_describe_job_failed_permanently(mock_start, *args):
303310
mock_job = Mock()
304311
mock_job.describe.side_effect = RuntimeError()
305312
mock_start.return_value = mock_job
@@ -333,7 +340,8 @@ def test_executor_describe_job_failed_permanently(mock_start):
333340
),
334341
],
335342
)
336-
def test_executor_submit_invalid_function_args(args, kwargs, error_message):
343+
@patch("sagemaker.remote_function.client._JobSettings")
344+
def test_executor_submit_invalid_function_args(MockJobSettings, args, kwargs, error_message):
337345
with pytest.raises(TypeError) as e:
338346
with RemoteExecutor(max_parallel_job=1, s3_root_uri="s3://bucket/") as executor:
339347
executor.submit(job_function, *args, **kwargs)

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import pytest
1516
from mock import patch, Mock, ANY
1617
from sagemaker.remote_function.job import _JobSettings, _Job
1718

@@ -80,13 +81,20 @@ def job_function(a, b=1, *, c, d=3):
8081
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
8182
@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
8283
def test_default_settings(*args):
83-
job_settings = _JobSettings()
84-
assert job_settings.image_uri == "fake_image_uri"
85-
assert job_settings.s3_root_uri == f"s3://{BUCKET}/fake_image_uri"
84+
job_settings = _JobSettings(image_uri="image_uri")
85+
assert job_settings.image_uri == "image_uri"
86+
assert job_settings.s3_root_uri == f"s3://{BUCKET}/image_uri"
8687
assert job_settings.role == DEFAULT_ROLE_ARN
8788
assert job_settings.tags == []
8889

8990

91+
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())
92+
@patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN)
93+
def test_fails_on_missing_image_uri(*args):
94+
with pytest.raises(ValueError):
95+
_JobSettings(image_uri=None)
96+
97+
9098
@patch("sagemaker.remote_function.job.RuntimeEnvironmentManager")
9199
@patch("sagemaker.remote_function.job.StoredFunction")
92100
@patch("sagemaker.remote_function.job.Session", return_value=mock_session())

0 commit comments

Comments
 (0)