Skip to content

Commit 9e70ce1

Browse files
authored
Merge branch 'master' into master-training-remote-debug
2 parents d1dd8ec + 8774952 commit 9e70ce1

File tree

15 files changed

+349
-90
lines changed

15 files changed

+349
-90
lines changed

src/sagemaker/config/config.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,10 @@
2828
from botocore.utils import merge_dicts
2929
from six.moves.urllib.parse import urlparse
3030
from sagemaker.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
31-
from sagemaker.config.config_utils import get_sagemaker_config_logger
31+
from sagemaker.config.config_utils import non_repeating_log_factory, get_sagemaker_config_logger
3232

3333
logger = get_sagemaker_config_logger()
34+
log_info_function = non_repeating_log_factory(logger, "info")
3435

3536
_APP_NAME = "sagemaker"
3637
# The default name of the config file.
@@ -52,7 +53,9 @@
5253
S3_PREFIX = "s3://"
5354

5455

55-
def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource=None) -> dict:
56+
def load_sagemaker_config(
57+
additional_config_paths: List[str] = None, s3_resource=None, repeat_log=False
58+
) -> dict:
5659
"""Loads config files and merges them.
5760
5861
By default, this method first searches for config files in the default locations
@@ -99,6 +102,8 @@ def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource
99102
<https://boto3.amazonaws.com/v1/documentation/api\
100103
/latest/reference/core/session.html#boto3.session.Session.resource>`__.
101104
This argument is not needed if the config files are present in the local file system.
105+
repeat_log (bool): Whether to emit a log message again when an identical message was already logged.
106+
Defaults to ``False``.
102107
"""
103108
default_config_path = os.getenv(
104109
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, _DEFAULT_ADMIN_CONFIG_FILE_PATH
@@ -109,6 +114,11 @@ def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource
109114
config_paths += additional_config_paths
110115
config_paths = list(filter(lambda item: item is not None, config_paths))
111116
merged_config = {}
117+
118+
log_info = log_info_function
119+
if repeat_log:
120+
log_info = logger.info
121+
112122
for file_path in config_paths:
113123
config_from_file = {}
114124
if file_path.startswith(S3_PREFIX):
@@ -130,9 +140,9 @@ def load_sagemaker_config(additional_config_paths: List[str] = None, s3_resource
130140
if config_from_file:
131141
validate_sagemaker_config(config_from_file)
132142
merge_dicts(merged_config, config_from_file)
133-
logger.info("Fetched defaults config from location: %s", file_path)
143+
log_info("Fetched defaults config from location: %s", file_path)
134144
else:
135-
logger.info("Not applying SDK defaults from location: %s", file_path)
145+
log_info("Not applying SDK defaults from location: %s", file_path)
136146

137147
return merged_config
138148

src/sagemaker/config/config_utils.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
These utils may be used inside or outside the config module.
1616
"""
1717
from __future__ import absolute_import
18+
from collections import deque
1819

1920
import logging
2021
import sys
22+
from typing import Callable
2123

2224

2325
def get_sagemaker_config_logger():
@@ -197,3 +199,33 @@ def _log_sagemaker_config_merge(
197199
else:
198200
# nothing was specified in the config and nothing is being automatically applied
199201
logger.debug("Skipped value because no value defined\n config key = %s", config_key_path)
202+
203+
204+
def non_repeating_log_factory(logger: logging.Logger, method: str, cache_size=100) -> Callable:
    """Create a log function that filters out recently repeated messages.

    The returned function remembers the most recent ``cache_size`` distinct
    messages it emitted (100 by default). A message identical to one still in
    that window is dropped; once the earlier occurrence has been pushed out of
    the window by newer messages, the message is emitted again.

    Args:
        logger (logging.Logger): the logger to be used to dispatch the message.
        method (str): the log method, can be info, warning or debug.
        cache_size (int): the number of last log messages to keep in cache.
            Defaults to 100.

    Returns:
        (Callable): the new log method. It accepts the same arguments as the
            wrapped ``logger`` method.

    Raises:
        ValueError: if ``method`` is not one of info, warning or debug.
    """
    if method not in ("info", "warning", "debug"):
        raise ValueError("Not supported logging method.")

    # Numeric level matching the chosen method (logging.INFO / WARNING / DEBUG).
    level = getattr(logging, method.upper())
    # Bounded FIFO: appending beyond ``maxlen`` silently evicts the oldest key.
    recent_keys = deque(maxlen=cache_size)
    log_method = getattr(logger, method)

    def new_log_method(msg, *args, **kwargs):
        # When the logger would drop the message anyway, do not record it as
        # "seen"; otherwise it would remain suppressed after the level is
        # later enabled even though it was never actually emitted.
        if not logger.isEnabledFor(level):
            return
        # Messages are identified by their format string plus positional args.
        key = f"{msg}:{args}"
        if key not in recent_keys:
            log_method(msg, *args, **kwargs)
            recent_keys.append(key)

    return new_log_method

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def get_deploy_kwargs(
338338
tolerate_vulnerable_model=tolerate_vulnerable_model,
339339
tolerate_deprecated_model=tolerate_deprecated_model,
340340
training_instance_type=training_instance_type,
341+
disable_instance_type_logging=True,
341342
)
342343

343344
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(

src/sagemaker/jumpstart/factory/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def _add_vulnerable_and_deprecated_status_to_kwargs(
171171
return kwargs
172172

173173

174-
def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
174+
def _add_instance_type_to_kwargs(
175+
kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False
176+
) -> JumpStartModelInitKwargs:
175177
"""Sets instance type based on default or override, returns full kwargs."""
176178

177179
orig_instance_type = kwargs.instance_type
@@ -187,7 +189,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
187189
training_instance_type=kwargs.training_instance_type,
188190
)
189191

190-
if orig_instance_type is None:
192+
if not disable_instance_type_logging and orig_instance_type is None:
191193
JUMPSTART_LOGGER.info(
192194
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
193195
kwargs.instance_type,
@@ -551,9 +553,7 @@ def get_deploy_kwargs(
551553

552554
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
553555

554-
deploy_kwargs = _add_instance_type_to_kwargs(
555-
kwargs=deploy_kwargs,
556-
)
556+
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
557557

558558
deploy_kwargs.initial_instance_count = initial_instance_count or 1
559559

@@ -677,6 +677,7 @@ def get_init_kwargs(
677677
git_config: Optional[Dict[str, str]] = None,
678678
model_package_arn: Optional[str] = None,
679679
training_instance_type: Optional[str] = None,
680+
disable_instance_type_logging: bool = False,
680681
resources: Optional[ResourceRequirements] = None,
681682
) -> JumpStartModelInitKwargs:
682683
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
@@ -720,7 +721,7 @@ def get_init_kwargs(
720721
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
721722

722723
model_init_kwargs = _add_instance_type_to_kwargs(
723-
kwargs=model_init_kwargs,
724+
kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging
724725
)
725726

726727
model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs)

src/sagemaker/model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -866,16 +866,10 @@ def _create_sagemaker_model(
866866
# _base_name, model_name are not needed under PipelineSession.
867867
# the model_data may be Pipeline variable
868868
# which may break the _base_name generation
869-
model_uri = None
870-
if isinstance(self.model_data, (str, PipelineVariable)):
871-
model_uri = self.model_data
872-
elif isinstance(self.model_data, dict):
873-
model_uri = self.model_data.get("S3DataSource", {}).get("S3Uri", None)
874-
875869
self._ensure_base_name_if_needed(
876870
image_uri=container_def["Image"],
877871
script_uri=self.source_dir,
878-
model_uri=model_uri,
872+
model_uri=self._get_model_uri(),
879873
)
880874
self._set_model_name_if_needed()
881875

@@ -912,6 +906,14 @@ def _create_sagemaker_model(
912906
)
913907
self.sagemaker_session.create_model(**create_model_args)
914908

909+
def _get_model_uri(self):
    """Return the model URI derived from ``self.model_data``.

    A ``str`` or ``PipelineVariable`` value is returned unchanged. A ``dict``
    value yields its nested ``["S3DataSource"]["S3Uri"]`` entry, or ``None``
    when either key is missing. Any other value yields ``None``.
    """
    data = self.model_data
    if isinstance(data, (str, PipelineVariable)):
        return data
    if isinstance(data, dict):
        s3_data_source = data.get("S3DataSource", {})
        return s3_data_source.get("S3Uri", None)
    return None
916+
915917
def _ensure_base_name_if_needed(self, image_uri, script_uri, model_uri):
916918
"""Create a base name from the image URI if there is no model name provided.
917919
@@ -1496,7 +1498,7 @@ def deploy(
14961498
self._ensure_base_name_if_needed(
14971499
image_uri=self.image_uri,
14981500
script_uri=self.source_dir,
1499-
model_uri=self.model_data,
1501+
model_uri=self._get_model_uri(),
15001502
)
15011503
if self._base_name is not None:
15021504
self._base_name = "-".join((self._base_name, compiled_model_suffix))

src/sagemaker/remote_function/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def wait(self, timeout: int = None):
891891
"""
892892

893893
self._last_describe_response = _logs_for_job(
894-
boto_session=self.sagemaker_session.boto_session,
894+
sagemaker_session=self.sagemaker_session,
895895
job_name=self.job_name,
896896
wait=True,
897897
timeout=timeout,

src/sagemaker/session.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
472472

473473
# Initialize the variables used to loop through the contents of the S3 bucket.
474474
keys = []
475+
directories = []
475476
next_token = ""
476477
base_parameters = {"Bucket": bucket, "Prefix": key_prefix}
477478

@@ -490,20 +491,26 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
490491
return []
491492
# For each object, save its key or directory.
492493
for s3_object in contents:
493-
key = s3_object.get("Key")
494-
keys.append(key)
494+
key: str = s3_object.get("Key")
495+
obj_size = s3_object.get("Size")
496+
if key.endswith("/") and int(obj_size) == 0:
497+
directories.append(os.path.join(path, key))
498+
else:
499+
keys.append(key)
495500
next_token = response.get("NextContinuationToken")
496501

497502
# For each object key, create the directory on the local machine if needed, and then
498503
# download the file.
499504
downloaded_paths = []
505+
for dir_path in directories:
506+
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
500507
for key in keys:
501508
tail_s3_uri_path = os.path.basename(key)
502509
if not os.path.splitext(key_prefix)[1]:
503510
tail_s3_uri_path = os.path.relpath(key, key_prefix)
504511
destination_path = os.path.join(path, tail_s3_uri_path)
505512
if not os.path.exists(os.path.dirname(destination_path)):
506-
os.makedirs(os.path.dirname(destination_path))
513+
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
507514
s3.download_file(
508515
Bucket=bucket, Key=key, Filename=destination_path, ExtraArgs=extra_args
509516
)
@@ -5495,7 +5502,7 @@ def logs_for_job(self, job_name, wait=False, poll=10, log_type="All", timeout=No
54955502
exceptions.CapacityError: If the training job fails with CapacityError.
54965503
exceptions.UnexpectedStatusException: If waiting and the training job fails.
54975504
"""
5498-
_logs_for_job(self.boto_session, job_name, wait, poll, log_type, timeout)
5505+
_logs_for_job(self, job_name, wait, poll, log_type, timeout)
54995506

55005507
def logs_for_processing_job(self, job_name, wait=False, poll=10):
55015508
"""Display logs for a given processing job, optionally tailing them until the is complete.
@@ -7378,17 +7385,16 @@ def _rule_statuses_changed(current_statuses, last_statuses):
73787385

73797386

73807387
def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
7381-
boto_session, job_name, wait=False, poll=10, log_type="All", timeout=None
7388+
sagemaker_session, job_name, wait=False, poll=10, log_type="All", timeout=None
73827389
):
73837390
"""Display logs for a given training job, optionally tailing them until job is complete.
73847391
73857392
If the output is a tty or a Jupyter cell, it will be color-coded
73867393
based on which instance the log entry is from.
73877394
73887395
Args:
7389-
boto_session (boto3.session.Session): The underlying Boto3 session which AWS service
7390-
calls are delegated to (default: None). If not provided, one is created with
7391-
default AWS configuration chain.
7396+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
7397+
object, used for SageMaker interactions.
73927398
job_name (str): Name of the training job to display the logs for.
73937399
wait (bool): Whether to keep looking for new log entries until the job completes
73947400
(default: False).
@@ -7405,13 +7411,13 @@ def _logs_for_job( # noqa: C901 - suppress complexity warning for this method
74057411
exceptions.CapacityError: If the training job fails with CapacityError.
74067412
exceptions.UnexpectedStatusException: If waiting and the training job fails.
74077413
"""
7408-
sagemaker_client = boto_session.client("sagemaker")
7414+
sagemaker_client = sagemaker_session.sagemaker_client
74097415
request_end_time = time.time() + timeout if timeout else None
74107416
description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
74117417
print(secondary_training_status_message(description, None), end="")
74127418

74137419
instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
7414-
boto_session, description, job="Training"
7420+
sagemaker_session.boto_session, description, job="Training"
74157421
)
74167422

74177423
state = _get_initial_job_state(description, "TrainingJobStatus", wait)

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ def test_with_additional_dependencies(
207207
def cuberoot(x):
208208
from scipy.special import cbrt
209209

210-
return cbrt(27)
210+
return cbrt(x)
211211

212212
assert cuberoot(27) == 3
213213

@@ -742,7 +742,7 @@ def test_with_user_and_workdir_set_in_the_image(
742742
def cuberoot(x):
743743
from scipy.special import cbrt
744744

745-
return cbrt(27)
745+
return cbrt(x)
746746

747747
assert cuberoot(27) == 3
748748

0 commit comments

Comments
 (0)