Skip to content

fix: Add retry in session.py to check if training is finished #3285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@
     secondary_training_status_changed,
     secondary_training_status_message,
     sts_regional_endpoint,
+    retries,
 )
 from sagemaker import exceptions
 from sagemaker.session_settings import SessionSettings
@@ -4699,21 +4700,30 @@ def _train_done(sagemaker_client, job_name, last_desc):
     """Placeholder docstring"""
     in_progress_statuses = ["InProgress", "Created"]

-    desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
-    status = desc["TrainingJobStatus"]
+    for _ in retries(
+        max_retry_count=10, # 10*30 = 5min
+        exception_message_prefix="Waiting for schedule to leave 'Pending' status",
+        seconds_to_sleep=30,
+    ):
+        try:
+            desc = sagemaker_client.describe_training_job(TrainingJobName=job_name)
+            status = desc["TrainingJobStatus"]

-    if secondary_training_status_changed(desc, last_desc):
-        print()
-        print(secondary_training_status_message(desc, last_desc), end="")
-    else:
-        print(".", end="")
-    sys.stdout.flush()
+            if secondary_training_status_changed(desc, last_desc):
+                print()
+                print(secondary_training_status_message(desc, last_desc), end="")
+            else:
+                print(".", end="")
+            sys.stdout.flush()

-    if status in in_progress_statuses:
-        return desc, False
+            if status in in_progress_statuses:
+                return desc, False

-    print()
-    return desc, True
+            print()
+            return desc, True
+        except botocore.exceptions.ClientError as err:
+            if err.response["Error"]["Code"] == "AccessDeniedException":
+                pass


 def _processing_job_status(sagemaker_client, job_name):
Expand Down