Skip to content

Commit 7acdbc4

Browse files
author
Chuyang Deng
committed
change: seperate logs() from attach()
1 parent 218d786 commit 7acdbc4

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,9 +630,15 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
630630
sagemaker_session=sagemaker_session, job_name=training_job_name
631631
)
632632
estimator._current_job_name = estimator.latest_training_job.name
633-
estimator.latest_training_job.wait()
633+
estimator.latest_training_job.wait(logs="None")
634634
return estimator
635635

636+
def logs(self):
637+
"""Display the logs for Estimator's training job. If the output is a tty or a Jupyter
638+
cell, it will be color-coded based on which instance the log entry is from.
639+
"""
640+
self.sagemaker_session.logs_for_job(self.latest_training_job, wait=True)
641+
636642
def deploy(
637643
self,
638644
initial_instance_count,

tests/unit/test_estimator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,27 @@ def test_attach_framework(sagemaker_session):
681681
assert framework_estimator.enable_network_isolation() is True
682682

683683

684+
def test_attach_no_logs(sagemaker_session):
685+
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
686+
mock_describe_training_job = Mock(
687+
name="describe_training_job", return_value=returned_job_description
688+
)
689+
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
690+
Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
691+
sagemaker_session.logs_for_job.assert_not_called()
692+
693+
694+
def test_logs(sagemaker_session):
695+
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
696+
mock_describe_training_job = Mock(
697+
name="describe_training_job", return_value=returned_job_description
698+
)
699+
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
700+
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
701+
estimator.logs()
702+
sagemaker_session.logs_for_job.assert_called_with(estimator.latest_training_job, wait=True)
703+
704+
684705
def test_attach_without_hyperparameters(sagemaker_session):
685706
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
686707
del returned_job_description["HyperParameters"]

0 commit comments

Comments
 (0)