Skip to content

Commit 5bc30a5

Browse files
committed
Merge branch 'zwei' into deprecate-get-image-uri
2 parents 9ca6ffd + bba192a commit 5bc30a5

File tree

2 files changed

+49
-48
lines changed

2 files changed

+49
-48
lines changed

src/sagemaker/estimator.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -589,14 +589,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
589589
has a Complete status, it can be ``deploy()`` ed to create a SageMaker
590590
Endpoint and return a ``Predictor``.
591591
592-
If the training job is in progress, attach will block and display log
593-
messages from the training job, until the training job completes.
592+
If the training job is in progress, attach will block until the training job
593+
completes, but logs of the training job will not display. To see the logs
594+
content, please call ``logs()``
594595
595596
Examples:
596597
>>> my_estimator.fit(wait=False)
597598
>>> training_job_name = my_estimator.latest_training_job.name
598599
Later on:
599600
>>> attached_estimator = Estimator.attach(training_job_name)
601+
>>> attached_estimator.logs()
600602
>>> attached_estimator.deploy()
601603
602604
Args:
@@ -630,9 +632,17 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
630632
sagemaker_session=sagemaker_session, job_name=training_job_name
631633
)
632634
estimator._current_job_name = estimator.latest_training_job.name
633-
estimator.latest_training_job.wait()
635+
estimator.latest_training_job.wait(logs="None")
634636
return estimator
635637

638+
def logs(self):
639+
"""Display the logs for Estimator's training job.
640+
641+
If the output is a tty or a Jupyter cell, it will be color-coded based
642+
on which instance the log entry is from.
643+
"""
644+
self.sagemaker_session.logs_for_job(self.latest_training_job, wait=True)
645+
636646
def deploy(
637647
self,
638648
initial_instance_count,
@@ -1842,14 +1852,16 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name="m
18421852
has a Complete status, it can be ``deploy()`` ed to create a SageMaker
18431853
Endpoint and return a ``Predictor``.
18441854
1845-
If the training job is in progress, attach will block and display log
1846-
messages from the training job, until the training job completes.
1855+
If the training job is in progress, attach will block until the training job
1856+
completes, but logs of the training job will not display. To see the logs
1857+
content, please call ``logs()``
18471858
18481859
Examples:
18491860
>>> my_estimator.fit(wait=False)
18501861
>>> training_job_name = my_estimator.latest_training_job.name
18511862
Later on:
18521863
>>> attached_estimator = Estimator.attach(training_job_name)
1864+
>>> attached_estimator.logs()
18531865
>>> attached_estimator.deploy()
18541866
18551867
Args:

tests/unit/test_estimator.py

Lines changed: 32 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,16 @@ def sagemaker_session():
190190
return sms
191191

192192

193+
@pytest.fixture()
194+
def training_job_description(sagemaker_session):
195+
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
196+
mock_describe_training_job = Mock(
197+
name="describe_training_job", return_value=returned_job_description
198+
)
199+
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
200+
return returned_job_description
201+
202+
193203
def test_framework_all_init_args(sagemaker_session):
194204
f = DummyFramework(
195205
"my_script.py",
@@ -650,13 +660,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session):
650660
assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"]
651661

652662

653-
def test_attach_framework(sagemaker_session):
654-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
655-
returned_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
656-
returned_job_description["EnableNetworkIsolation"] = True
657-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
658-
name="describe_training_job", return_value=returned_job_description
659-
)
663+
def test_attach_framework(sagemaker_session, training_job_description):
664+
training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
665+
training_job_description["EnableNetworkIsolation"] = True
660666

661667
framework_estimator = DummyFramework.attach(
662668
training_job_name="neo", sagemaker_session=sagemaker_session
@@ -680,29 +686,25 @@ def test_attach_framework(sagemaker_session):
680686
assert framework_estimator.enable_network_isolation() is True
681687

682688

683-
def test_attach_without_hyperparameters(sagemaker_session):
684-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
685-
del returned_job_description["HyperParameters"]
689+
def test_attach_no_logs(sagemaker_session, training_job_description):
690+
Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
691+
sagemaker_session.logs_for_job.assert_not_called()
686692

687-
mock_describe_training_job = Mock(
688-
name="describe_training_job", return_value=returned_job_description
689-
)
690-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
691693

694+
def test_logs(sagemaker_session, training_job_description):
692695
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
693-
694-
assert estimator.hyperparameters() == {}
696+
estimator.logs()
697+
sagemaker_session.logs_for_job.assert_called_with(estimator.latest_training_job, wait=True)
695698

696699

697-
def test_attach_framework_with_tuning(sagemaker_session):
698-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
699-
returned_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy"
700+
def test_attach_without_hyperparameters(sagemaker_session, training_job_description):
701+
del training_job_description["HyperParameters"]
702+
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)
703+
assert estimator.hyperparameters() == {}
700704

701-
mock_describe_training_job = Mock(
702-
name="describe_training_job", return_value=returned_job_description
703-
)
704-
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job
705705

706+
def test_attach_framework_with_tuning(sagemaker_session, training_job_description):
707+
training_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy"
706708
framework_estimator = DummyFramework.attach(
707709
training_job_name="neo", sagemaker_session=sagemaker_session
708710
)
@@ -722,48 +724,35 @@ def test_attach_framework_with_tuning(sagemaker_session):
722724
assert framework_estimator.encrypt_inter_container_traffic is False
723725

724726

725-
def test_attach_framework_with_model_channel(sagemaker_session):
727+
def test_attach_framework_with_model_channel(sagemaker_session, training_job_description):
726728
s3_uri = "s3://some/s3/path/model.tar.gz"
727-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
728-
returned_job_description["InputDataConfig"] = [
729+
training_job_description["InputDataConfig"] = [
729730
{
730731
"ChannelName": "model",
731732
"InputMode": "File",
732733
"DataSource": {"S3DataSource": {"S3Uri": s3_uri}},
733734
}
734735
]
735736

736-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
737-
name="describe_training_job", return_value=returned_job_description
738-
)
739-
740737
framework_estimator = DummyFramework.attach(
741738
training_job_name="neo", sagemaker_session=sagemaker_session
742739
)
743740
assert framework_estimator.model_uri is s3_uri
744741
assert framework_estimator.encrypt_inter_container_traffic is False
745742

746743

747-
def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker_session):
748-
returned_job_description = RETURNED_JOB_DESCRIPTION.copy()
749-
returned_job_description["EnableInterContainerTrafficEncryption"] = True
750-
751-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
752-
name="describe_training_job", return_value=returned_job_description
753-
)
754-
744+
def test_attach_framework_with_inter_container_traffic_encryption_flag(
745+
sagemaker_session, training_job_description
746+
):
747+
training_job_description["EnableInterContainerTrafficEncryption"] = True
755748
framework_estimator = DummyFramework.attach(
756749
training_job_name="neo", sagemaker_session=sagemaker_session
757750
)
758751

759752
assert framework_estimator.encrypt_inter_container_traffic is True
760753

761754

762-
def test_attach_framework_base_from_generated_name(sagemaker_session):
763-
sagemaker_session.sagemaker_client.describe_training_job = Mock(
764-
name="describe_training_job", return_value=RETURNED_JOB_DESCRIPTION
765-
)
766-
755+
def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description):
767756
base_job_name = "neo"
768757
framework_estimator = DummyFramework.attach(
769758
training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session

0 commit comments

Comments
 (0)