Skip to content

Commit 500bd36

Browse files
jswudi, ChoiByungWook
authored and committed
fix: run UpdateTrainingJob tests only during allowed secondary status (#552)
1 parent 8b9cca7 commit 500bd36

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/integ/test_profiler.py

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
import re
17+
import time
1718
import uuid
1819

1920
import pytest
@@ -34,6 +35,29 @@
3435
from tests.integ.timeout import timeout
3536

3637

38+
# Secondary status value that _check_secondary_status compares against the
# "Status" field of the job's latest secondary status transition.
TRAINING_STATUS = "Training"
39+
# Status message that _check_secondary_status compares against the
# "StatusMessage" field of the job's latest secondary status transition.
ALGO_PULL_FINISHED_MESSAGE = "Training image download completed. Training in progress."
40+
41+
42+
def _wait_until_training_can_be_updated(sagemaker_client, job_name, poll=5):
    """Block until ``_check_secondary_status`` reports the job as updatable.

    The job is polled every ``poll`` seconds. Because the readiness check only
    succeeds while the latest secondary status transition matches one exact
    status/message pair, a job that has already reached a terminal status can
    never become ready; in that case this helper raises instead of looping
    forever.

    Args:
        sagemaker_client: Client exposing ``describe_training_job``.
        job_name: Name of the training job to poll.
        poll: Seconds to sleep between readiness checks.

    Raises:
        RuntimeError: If the job reports a terminal ``TrainingJobStatus``
            before it becomes updatable.
    """
    terminal_statuses = ("Completed", "Failed", "Stopped")

    while not _check_secondary_status(sagemaker_client, job_name):
        # Not ready yet: make sure waiting can still succeed before sleeping.
        desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
        job_status = desc.get("TrainingJobStatus")
        if job_status in terminal_statuses:
            raise RuntimeError(
                "Training job {} reached terminal status {} before it could be updated.".format(
                    job_name, job_status
                )
            )
        time.sleep(poll)
47+
48+
49+
def _check_secondary_status(sagemaker_client, job_name):
    """Return True when the job's latest secondary status allows an update.

    Describes the training job and inspects the last entry of its
    ``SecondaryStatusTransitions`` list. The result is True only when that
    entry's ``Status`` equals ``TRAINING_STATUS`` and its ``StatusMessage``
    equals ``ALGO_PULL_FINISHED_MESSAGE``. A missing or empty transition list
    yields False.

    Args:
        sagemaker_client: Client exposing ``describe_training_job``.
        job_name: Name of the training job to describe.

    Returns:
        bool: Whether the latest transition matches the expected pair.
    """
    description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    transitions = description.get("SecondaryStatusTransitions")

    if transitions:
        newest = transitions[-1]
        observed = (newest.get("Status"), newest.get("StatusMessage"))
        return observed == (TRAINING_STATUS, ALGO_PULL_FINISHED_MESSAGE)

    # No transitions recorded yet, so the job cannot be in the expected state.
    return False
59+
60+
3761
def test_mxnet_with_default_profiler_config_and_profiler_rule(
3862
sagemaker_session,
3963
mxnet_training_latest_version,
@@ -139,6 +163,8 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config(
139163
)
140164
assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"}
141165

166+
_wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name)
167+
142168
mx.update_profiler(
143169
rules=[ProfilerRule.sagemaker(rule_configs.CPUBottleneck())],
144170
system_monitor_interval_millis=500,
@@ -287,6 +313,8 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics(
287313
== rule.image_uri
288314
)
289315

316+
_wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name)
317+
290318
mx.update_profiler(disable_framework_metrics=True)
291319
job_description = mx.latest_training_job.describe()
292320
assert job_description["ProfilerConfig"]["ProfilingParameters"] == {}
@@ -338,6 +366,8 @@ def test_mxnet_with_enable_framework_metrics_then_update_framework_metrics(
338366
)
339367
assert job_description.get("ProfilingStatus") == "Enabled"
340368

369+
_wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name)
370+
341371
updated_framework_profile = FrameworkProfile(
342372
detailed_profiling_config=DetailedProfilingConfig(profile_default_steps=True)
343373
)
@@ -397,6 +427,8 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling(
397427
assert job_description.get("ProfilerRuleConfigurations") is None
398428
assert job_description.get("ProfilingStatus") == "Disabled"
399429

430+
_wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name)
431+
400432
mx.enable_default_profiling()
401433

402434
job_description = mx.latest_training_job.describe()

0 commit comments

Comments
 (0)