Skip to content

Commit bba192a

Browse files
chuyang-dengChuyang Deng
andauthored
change: separate logs() from attach() (#1708)
* change: seperate logs() from attach() * make fixture for training_job_description Co-authored-by: Chuyang Deng <[email protected]>
1 parent 2bb6713 commit bba192a

File tree

2 files changed

+49
-48
lines changed

2 files changed

+49
-48
lines changed

src/sagemaker/estimator.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -589,14 +589,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
589589
has a Complete status, it can be ``deploy()`` ed to create a SageMaker
590590
Endpoint and return a ``Predictor``.
591591
592-
If the training job is in progress, attach will block and display log
593-
messages from the training job, until the training job completes.
592+
If the training job is in progress, attach will block until the training job
593+
completes, but logs of the training job will not display. To see the logs
594+
content, please call ``logs()``
594595
595596
Examples:
596597
>>> my_estimator.fit(wait=False)
597598
>>> training_job_name = my_estimator.latest_training_job.name
598599
Later on:
599600
>>> attached_estimator = Estimator.attach(training_job_name)
601+
>>> attached_estimator.logs()
600602
>>> attached_estimator.deploy()
601603
602604
Args:
@@ -630,9 +632,17 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
630632
sagemaker_session=sagemaker_session, job_name=training_job_name
631633
)
632634
estimator._current_job_name = estimator.latest_training_job.name
633-
estimator.latest_training_job.wait()
635+
estimator.latest_training_job.wait(logs="None")
634636
return estimator
635637

638+
def logs(self):
639+
"""Display the logs for Estimator's training job.
640+
641+
If the output is a tty or a Jupyter cell, it will be color-coded based
642+
on which instance the log entry is from.
643+
"""
644+
self.sagemaker_session.logs_for_job(self.latest_training_job, wait=True)
645+
636646
def deploy(
637647
self,
638648
initial_instance_count,
@@ -1842,14 +1852,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
18421852
has a Complete status, it can be ``deploy()`` ed to create a SageMaker
18431853
Endpoint and return a ``Predictor``.
18441854
1845-
If the training job is in progress, attach will block and display log
1846-
messages from the training job, until the training job completes.
1855+
If the training job is in progress, attach will block until the training job
1856+
completes, but logs of the training job will not display. To see the logs
1857+
content, please call ``logs()``
18471858
18481859
Examples:
18491860
>>> my_estimator.fit(wait=False)
18501861
>>> training_job_name = my_estimator.latest_training_job.name
18511862
Later on:
18521863
>>> attached_estimator = Estimator.attach(training_job_name)
1864+
>>> attached_estimator.logs()
18531865
>>> attached_estimator.deploy()
18541866
18551867
Args:

tests/unit/test_estimator.py

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,16 @@ def sagemaker_session():
191191
return sms
192192

193193

194+
@pytest.fixture()
195+
def training_job_description(sagemaker_session):
196+
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
197+
mock_describe_training_job = Mock(
198+
name="describe_training_job", return_value=returned_job_description
199+
)
200+
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
201+
return returned_job_description
202+
203+
194204
def test_framework_all_init_args(sagemaker_session):
195205
f = DummyFramework(
196206
"my_script.py",
@@ -651,13 +661,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session):
651661
assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"]
652662

653663

654-
def test_attach_framework(sagemaker_session):
655-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
656-
returned_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
657-
returned_job_description["EnableNetworkIsolation"] = True
658-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
659-
name="describe_training_job", return_value=returned_job_description
660-
)
664+
def test_attach_framework(sagemaker_session, training_job_description):
665+
training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
666+
training_job_description["EnableNetworkIsolation"] = True
661667

662668
framework_estimator = DummyFramework.attach(
663669
training_job_name="neo", sagemaker_session=sagemaker_session
@@ -681,29 +687,25 @@ def test_attach_framework(sagemaker_session):
681687
assert framework_estimator.enable_network_isolation() is True
682688

683689

684-
def test_attach_without_hyperparameters(sagemaker_session):
685-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
686-
del returned_job_description["HyperParameters"]
690+
def test_attach_no_logs(sagemaker_session, training_job_description):
691+
Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
692+
sagemaker_session.logs_for_job.assert_not_called()
687693

688-
mock_describe_training_job = Mock(
689-
name="describe_training_job", return_value=returned_job_description
690-
)
691-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
692694

695+
def test_logs(sagemaker_session, training_job_description):
693696
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
694-
695-
assert estimator.hyperparameters() == {}
697+
estimator.logs()
698+
sagemaker_session.logs_for_job.assert_called_with(estimator.latest_training_job, wait=True)
696699

697700

698-
def test_attach_framework_with_tuning(sagemaker_session):
699-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
700-
returned_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy"
701+
def test_attach_without_hyperparameters(sagemaker_session, training_job_description):
702+
del training_job_description["HyperParameters"]
703+
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
704+
assert estimator.hyperparameters() == {}
701705

702-
mock_describe_training_job = Mock(
703-
name="describe_training_job", return_value=returned_job_description
704-
)
705-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
706706

707+
def test_attach_framework_with_tuning(sagemaker_session, training_job_description):
708+
training_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy"
707709
framework_estimator = DummyFramework.attach(
708710
training_job_name="neo", sagemaker_session=sagemaker_session
709711
)
@@ -723,48 +725,35 @@ def test_attach_framework_with_tuning(sagemaker_session):
723725
assert framework_estimator.encrypt_inter_container_traffic is False
724726

725727

726-
def test_attach_framework_with_model_channel(sagemaker_session):
728+
def test_attach_framework_with_model_channel(sagemaker_session, training_job_description):
727729
s3_uri = "s3://some/s3/path/model.tar.gz"
728-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
729-
returned_job_description["InputDataConfig"] = [
730+
training_job_description["InputDataConfig"] = [
730731
{
731732
"ChannelName": "model",
732733
"InputMode": "File",
733734
"DataSource": {"S3DataSource": {"S3Uri": s3_uri}},
734735
}
735736
]
736737

737-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
738-
name="describe_training_job", return_value=returned_job_description
739-
)
740-
741738
framework_estimator = DummyFramework.attach(
742739
training_job_name="neo", sagemaker_session=sagemaker_session
743740
)
744741
assert framework_estimator.model_uri is s3_uri
745742
assert framework_estimator.encrypt_inter_container_traffic is False
746743

747744

748-
def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker_session):
749-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
750-
returned_job_description["EnableInterContainerTrafficEncryption"] = True
751-
752-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
753-
name="describe_training_job", return_value=returned_job_description
754-
)
755-
745+
def test_attach_framework_with_inter_container_traffic_encryption_flag(
746+
sagemaker_session, training_job_description
747+
):
748+
training_job_description["EnableInterContainerTrafficEncryption"] = True
756749
framework_estimator = DummyFramework.attach(
757750
training_job_name="neo", sagemaker_session=sagemaker_session
758751
)
759752

760753
assert framework_estimator.encrypt_inter_container_traffic is True
761754

762755

763-
def test_attach_framework_base_from_generated_name(sagemaker_session):
764-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
765-
name="describe_training_job", return_value=RETURNED_JOB_DESCRIPTION
766-
)
767-
756+
def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description):
768757
base_job_name = "neo"
769758
framework_estimator = DummyFramework.attach(
770759
training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session

0 commit comments

Comments
 (0)