|
14 | 14 |
|
15 | 15 | import os
|
16 | 16 | import re
|
| 17 | +import time |
17 | 18 | import uuid
|
18 | 19 |
|
19 | 20 | import pytest
|
|
34 | 35 | from tests.integ.timeout import timeout
|
35 | 36 |
|
36 | 37 |
|
# Secondary status that `_check_secondary_status` expects on the newest entry of the
# job's "SecondaryStatusTransitions" list.
TRAINING_STATUS = "Training"

# Status message that must accompany TRAINING_STATUS on that same entry; the comparison
# is an exact string match, so this text must not be reworded.
ALGO_PULL_FINISHED_MESSAGE = (
    "Training image download completed. Training in progress."
)
| 40 | + |
| 41 | + |
def _wait_until_training_can_be_updated(sagemaker_client, job_name, poll=5, max_wait=None):
    """Block until ``_check_secondary_status`` reports the training job as ready.

    The job is probed immediately; if it is not ready yet, the function sleeps ``poll``
    seconds and probes again, repeating until the probe succeeds.

    Args:
        sagemaker_client: Client object forwarded unchanged to ``_check_secondary_status``.
        job_name (str): Name of the training job to watch.
        poll (int or float): Seconds to sleep between two consecutive probes.
        max_wait (int or float): Optional upper bound, in seconds, on the total time spent
            waiting. ``None`` (the default) keeps polling indefinitely.

    Raises:
        TimeoutError: If ``max_wait`` is given and elapses before the job is ready.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = None if max_wait is None else time.monotonic() + max_wait

    while not _check_secondary_status(sagemaker_client, job_name):
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                "Training job {} was not ready to be updated within {} seconds".format(
                    job_name, max_wait
                )
            )
        time.sleep(poll)
| 47 | + |
| 48 | + |
def _check_secondary_status(sagemaker_client, job_name):
    """Tell whether the job's newest secondary status transition marks training as started.

    Args:
        sagemaker_client: Object exposing ``describe_training_job(TrainingJobName=...)``.
        job_name (str): Name of the training job to describe.

    Returns:
        bool: ``True`` only when the last entry of "SecondaryStatusTransitions" has
        "Status" equal to ``TRAINING_STATUS`` and "StatusMessage" equal to
        ``ALGO_PULL_FINISHED_MESSAGE``; ``False`` otherwise, including when the
        description carries no transitions at all.
    """
    description = sagemaker_client.describe_training_job(TrainingJobName=job_name)
    transitions = description.get("SecondaryStatusTransitions")

    if transitions:
        # Only the most recent transition is relevant; earlier entries are history.
        newest = transitions[-1]
        observed = (newest.get("Status"), newest.get("StatusMessage"))
        return (TRAINING_STATUS, ALGO_PULL_FINISHED_MESSAGE) == observed

    # Missing key or empty list: the job has not reported any transition yet.
    return False
| 59 | + |
| 60 | + |
37 | 61 | def test_mxnet_with_default_profiler_config_and_profiler_rule(
|
38 | 62 | sagemaker_session,
|
39 | 63 | mxnet_training_latest_version,
|
@@ -139,6 +163,8 @@ def test_mxnet_with_custom_profiler_config_then_update_rule_and_config(
|
139 | 163 | )
|
140 | 164 | assert profiler_rule_configuration["RuleParameters"] == {"rule_to_invoke": "ProfilerReport"}
|
141 | 165 |
|
| 166 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 167 | + |
142 | 168 | mx.update_profiler(
|
143 | 169 | rules=[ProfilerRule.sagemaker(rule_configs.CPUBottleneck())],
|
144 | 170 | system_monitor_interval_millis=500,
|
@@ -287,6 +313,8 @@ def test_mxnet_with_profiler_and_debugger_then_disable_framework_metrics(
|
287 | 313 | == rule.image_uri
|
288 | 314 | )
|
289 | 315 |
|
| 316 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 317 | + |
290 | 318 | mx.update_profiler(disable_framework_metrics=True)
|
291 | 319 | job_description = mx.latest_training_job.describe()
|
292 | 320 | assert job_description["ProfilerConfig"]["ProfilingParameters"] == {}
|
@@ -338,6 +366,8 @@ def test_mxnet_with_enable_framework_metrics_then_update_framework_metrics(
|
338 | 366 | )
|
339 | 367 | assert job_description.get("ProfilingStatus") == "Enabled"
|
340 | 368 |
|
| 369 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 370 | + |
341 | 371 | updated_framework_profile = FrameworkProfile(
|
342 | 372 | detailed_profiling_config=DetailedProfilingConfig(profile_default_steps=True)
|
343 | 373 | )
|
@@ -397,6 +427,8 @@ def test_mxnet_with_disable_profiler_then_enable_default_profiling(
|
397 | 427 | assert job_description.get("ProfilerRuleConfigurations") is None
|
398 | 428 | assert job_description.get("ProfilingStatus") == "Disabled"
|
399 | 429 |
|
| 430 | + _wait_until_training_can_be_updated(sagemaker_session.sagemaker_client, training_job_name) |
| 431 | + |
400 | 432 | mx.enable_default_profiling()
|
401 | 433 |
|
402 | 434 | job_description = mx.latest_training_job.describe()
|
|
0 commit comments