-
Notifications
You must be signed in to change notification settings - Fork 1.2k
change: separate logs() from attach() #1708
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7acdbc4
c5dafc5
17610e0
35c3daa
612ae57
3e3d8b9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,6 +191,16 @@ def sagemaker_session(): | |
return sms | ||
|
||
|
||
@pytest.fixture() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this per test or per module? (aka are there potential race conditions with the modification of this dictionary in the tests?) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The pytest.fixture by default is run once per test function: https://docs.pytest.org/en/stable/fixture.html#scope-sharing-a-fixture-instance-across-tests-in-a-class-module-or-session Each test is modifying the |
||
def training_job_description(sagemaker_session): | ||
returned_job_description = RETURNED_JOB_DESCRIPTION.copy() | ||
mock_describe_training_job = Mock( | ||
name="describe_training_job", return_value=returned_job_description | ||
) | ||
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job | ||
return returned_job_description | ||
|
||
|
||
def test_framework_all_init_args(sagemaker_session): | ||
f = DummyFramework( | ||
"my_script.py", | ||
|
@@ -651,13 +661,9 @@ def test_enable_cloudwatch_metrics(sagemaker_session): | |
assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"] | ||
|
||
|
||
def test_attach_framework(sagemaker_session): | ||
returned_job_description = RETURNED_JOB_DESCRIPTION.copy() | ||
returned_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} | ||
returned_job_description["EnableNetworkIsolation"] = True | ||
sagemaker_session.sagemaker_client.describe_training_job = Mock( | ||
name="describe_training_job", return_value=returned_job_description | ||
) | ||
def test_attach_framework(sagemaker_session, training_job_description): | ||
training_job_description["VpcConfig"] = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} | ||
training_job_description["EnableNetworkIsolation"] = True | ||
|
||
framework_estimator = DummyFramework.attach( | ||
training_job_name="neo", sagemaker_session=sagemaker_session | ||
|
@@ -681,29 +687,25 @@ def test_attach_framework(sagemaker_session): | |
assert framework_estimator.enable_network_isolation() is True | ||
|
||
|
||
def test_attach_without_hyperparameters(sagemaker_session): | ||
returned_job_description = RETURNED_JOB_DESCRIPTION.copy() | ||
del returned_job_description["HyperParameters"] | ||
def test_attach_no_logs(sagemaker_session, training_job_description): | ||
Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) | ||
sagemaker_session.logs_for_job.assert_not_called() | ||
|
||
mock_describe_training_job = Mock( | ||
name="describe_training_job", return_value=returned_job_description | ||
) | ||
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job | ||
|
||
def test_logs(sagemaker_session, training_job_description): | ||
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) | ||
|
||
assert estimator.hyperparameters() == {} | ||
estimator.logs() | ||
sagemaker_session.logs_for_job.assert_called_with(estimator.latest_training_job, wait=True) | ||
|
||
|
||
def test_attach_framework_with_tuning(sagemaker_session): | ||
returned_job_description = RETURNED_JOB_DESCRIPTION.copy() | ||
returned_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy" | ||
def test_attach_without_hyperparameters(sagemaker_session, training_job_description): | ||
del training_job_description["HyperParameters"] | ||
estimator = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session) | ||
assert estimator.hyperparameters() == {} | ||
|
||
mock_describe_training_job = Mock( | ||
name="describe_training_job", return_value=returned_job_description | ||
) | ||
sagemaker_session.sagemaker_client.describe_training_job = mock_describe_training_job | ||
|
||
def test_attach_framework_with_tuning(sagemaker_session, training_job_description): | ||
training_job_description["HyperParameters"]["_tuning_objective_metric"] = "Validation-accuracy" | ||
framework_estimator = DummyFramework.attach( | ||
training_job_name="neo", sagemaker_session=sagemaker_session | ||
) | ||
|
@@ -723,48 +725,35 @@ def test_attach_framework_with_tuning(sagemaker_session): | |
assert framework_estimator.encrypt_inter_container_traffic is False | ||
|
||
|
||
def test_attach_framework_with_model_channel(sagemaker_session): | ||
def test_attach_framework_with_model_channel(sagemaker_session, training_job_description): | ||
s3_uri = "s3://some/s3/path/model.tar.gz" | ||
returned_job_description = RETURNED_JOB_DESCRIPTION.copy() | ||
returned_job_description["InputDataConfig"] = [ | ||
training_job_description["InputDataConfig"] = [ | ||
{ | ||
"ChannelName": "model", | ||
"InputMode": "File", | ||
"DataSource": {"S3DataSource": {"S3Uri": s3_uri}}, | ||
} | ||
] | ||
|
||
sagemaker_session.sagemaker_client.describe_training_job = Mock( | ||
name="describe_training_job", return_value=returned_job_description | ||
) | ||
|
||
framework_estimator = DummyFramework.attach( | ||
training_job_name="neo", sagemaker_session=sagemaker_session | ||
) | ||
assert framework_estimator.model_uri is s3_uri | ||
assert framework_estimator.encrypt_inter_container_traffic is False | ||
|
||
|
||
def test_attach_framework_with_inter_container_traffic_encryption_flag(sagemaker_session): | ||
returned_job_description = RETURNED_JOB_DESCRIPTION.copy() | ||
returned_job_description["EnableInterContainerTrafficEncryption"] = True | ||
|
||
sagemaker_session.sagemaker_client.describe_training_job = Mock( | ||
name="describe_training_job", return_value=returned_job_description | ||
) | ||
|
||
def test_attach_framework_with_inter_container_traffic_encryption_flag( | ||
sagemaker_session, training_job_description | ||
): | ||
training_job_description["EnableInterContainerTrafficEncryption"] = True | ||
framework_estimator = DummyFramework.attach( | ||
training_job_name="neo", sagemaker_session=sagemaker_session | ||
) | ||
|
||
assert framework_estimator.encrypt_inter_container_traffic is True | ||
|
||
|
||
def test_attach_framework_base_from_generated_name(sagemaker_session): | ||
sagemaker_session.sagemaker_client.describe_training_job = Mock( | ||
name="describe_training_job", return_value=RETURNED_JOB_DESCRIPTION | ||
) | ||
|
||
def test_attach_framework_base_from_generated_name(sagemaker_session, training_job_description): | ||
base_job_name = "neo" | ||
framework_estimator = DummyFramework.attach( | ||
training_job_name=utils.name_from_base("neo"), sagemaker_session=sagemaker_session | ||
|
Uh oh!
There was an error while loading. Please reload this page.