Skip to content

Commit a47a4fa

Browse files
committed
doc: update docstrings to comply with PEP257
1 parent 64371d3 commit a47a4fa

File tree

1 file changed

+11
-189
lines changed

1 file changed

+11
-189
lines changed

src/sagemaker/session.py

Lines changed: 11 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,6 @@ def train( # noqa: C901
462462
debugger_hook_config=None,
463463
tensorboard_output_config=None,
464464
enable_sagemaker_metrics=None,
465-
profiler_rule_configs=None,
466-
profiler_config=None,
467465
):
468466
"""Create an Amazon SageMaker training job.
469467
@@ -533,9 +531,6 @@ def train( # noqa: C901
533531
Series. For more information see:
534532
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
535533
(default: ``None``).
536-
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
537-
profiler_config (dict): Configuration for how profiling information is emitted
538-
with SageMaker Profiler. (default: ``None``).
539534
540535
Returns:
541536
str: ARN of the training job, if it is created.
@@ -565,8 +560,6 @@ def train( # noqa: C901
565560
debugger_hook_config=debugger_hook_config,
566561
tensorboard_output_config=tensorboard_output_config,
567562
enable_sagemaker_metrics=enable_sagemaker_metrics,
568-
profiler_rule_configs=profiler_rule_configs,
569-
profiler_config=profiler_config,
570563
)
571564
LOGGER.info("Creating training-job with name: %s", job_name)
572565
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
@@ -597,8 +590,6 @@ def _get_train_request( # noqa: C901
597590
debugger_hook_config=None,
598591
tensorboard_output_config=None,
599592
enable_sagemaker_metrics=None,
600-
profiler_rule_configs=None,
601-
profiler_config=None,
602593
):
603594
"""Constructs a request compatible for creating an Amazon SageMaker training job.
604595
@@ -668,9 +659,6 @@ def _get_train_request( # noqa: C901
668659
Series. For more information see:
669660
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
670661
(default: ``None``).
671-
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
672-
profiler_config(dict): Configuration for how profiling information is emitted with
673-
SageMaker Profiler. (default: ``None``).
674662
675663
Returns:
676664
Dict: a training request dict
@@ -746,66 +734,8 @@ def _get_train_request( # noqa: C901
746734
if tensorboard_output_config is not None:
747735
train_request["TensorBoardOutputConfig"] = tensorboard_output_config
748736

749-
if profiler_rule_configs is not None:
750-
train_request["ProfilerRuleConfigurations"] = profiler_rule_configs
751-
752-
if profiler_config is not None:
753-
train_request["ProfilerConfig"] = profiler_config
754-
755737
return train_request
756738

757-
def update_training_job(
758-
self,
759-
job_name,
760-
profiler_rule_configs=None,
761-
profiler_config=None,
762-
):
763-
"""Calls the UpdateTrainingJob API for the given job name and returns the response.
764-
765-
Args:
766-
job_name (str): Name of the training job being updated.
767-
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
768-
profiler_config(dict): Configuration for how profiling information is emitted with
769-
SageMaker Profiler. (default: ``None``).
770-
"""
771-
update_training_job_request = self._get_update_training_job_request(
772-
job_name=job_name,
773-
profiler_rule_configs=profiler_rule_configs,
774-
profiler_config=profiler_config,
775-
)
776-
LOGGER.info("Updating training job with name %s", job_name)
777-
LOGGER.debug("Update request: %s", json.dumps(update_training_job_request, indent=4))
778-
self.sagemaker_client.update_training_job(**update_training_job_request)
779-
780-
def _get_update_training_job_request(
781-
self,
782-
job_name,
783-
profiler_rule_configs=None,
784-
profiler_config=None,
785-
):
786-
"""Constructs a request compatible for updateing an Amazon SageMaker training job.
787-
788-
Args:
789-
job_name (str): Name of the training job being updated.
790-
profiler_rule_configs (list): List of profiler rule configurations. (default: ``None``).
791-
profiler_config(dict): Configuration for how profiling information is emitted with
792-
SageMaker Profiler. (default: ``None``).
793-
794-
Returns:
795-
Dict: an update training request dict
796-
"""
797-
update_training_job_request = {
798-
"TrainingJobName": job_name,
799-
}
800-
801-
if profiler_rule_configs is not None:
802-
update_training_job_request["ProfilerRuleConfigurations"] = profiler_rule_configs
803-
804-
if profiler_config is not None:
805-
update_training_job_request["ProfilerConfig"] = profiler_config
806-
807-
return update_training_job_request
808-
809739
def process(
810740
self,
811741
inputs,
@@ -1802,48 +1732,6 @@ def compile_model(
18021732
LOGGER.info("Creating compilation-job with name: %s", job_name)
18031733
self.sagemaker_client.create_compilation_job(**compilation_job_request)
18041734

1805-
def package_model_for_edge(
1806-
self,
1807-
output_model_config,
1808-
role,
1809-
job_name,
1810-
compilation_job_name,
1811-
model_name,
1812-
model_version,
1813-
resource_key,
1814-
tags,
1815-
):
1816-
"""Create an Amazon SageMaker Edge packaging job.
1817-
1818-
Args:
1819-
output_model_config (dict): Identifies the Amazon S3 location where you want Amazon
1820-
SageMaker Edge to save the results of edge packaging job
1821-
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker Edge
1822-
edge packaging jobs use this role to access model artifacts. You must grant
1823-
sufficient permissions to this role.
1824-
job_name (str): Name of the edge packaging job being created.
1825-
compilation_job_name (str): Name of the compilation job being created.
1826-
resource_key (str): KMS key to encrypt the disk used to package the job
1827-
tags (list[dict]): List of tags for labeling a compile model job. For more, see
1828-
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
1829-
"""
1830-
edge_packaging_job_request = {
1831-
"OutputConfig": output_model_config,
1832-
"RoleArn": role,
1833-
"ModelName": model_name,
1834-
"ModelVersion": model_version,
1835-
"EdgePackagingJobName": job_name,
1836-
"CompilationJobName": compilation_job_name,
1837-
}
1838-
1839-
if tags is not None:
1840-
edge_packaging_job_request["Tags"] = tags
1841-
if resource_key is not None:
1842-
edge_packaging_job_request["ResourceKey"] = (resource_key,)
1843-
1844-
LOGGER.info("Creating edge-packaging-job with name: %s", job_name)
1845-
self.sagemaker_client.create_edge_packaging_job(**edge_packaging_job_request)
1846-
18471735
def tune( # noqa: C901
18481736
self,
18491737
job_name,
@@ -3150,23 +3038,6 @@ def wait_for_compilation_job(self, job, poll=5):
31503038
self._check_job_status(job, desc, "CompilationJobStatus")
31513039
return desc
31523040

3153-
def wait_for_edge_packaging_job(self, job, poll=5):
3154-
"""Wait for an Amazon SageMaker Edge packaging job to complete.
3155-
3156-
Args:
3157-
job (str): Name of the edge packaging job to wait for.
3158-
poll (int): Polling interval in seconds (default: 5).
3159-
3160-
Returns:
3161-
(dict): Return value from the ``DescribeEdgePackagingJob`` API.
3162-
3163-
Raises:
3164-
exceptions.UnexpectedStatusException: If the compilation job fails.
3165-
"""
3166-
desc = _wait_until(lambda: _edge_packaging_job_status(self.sagemaker_client, job), poll)
3167-
self._check_job_status(job, desc, "EdgePackagingJobStatus")
3168-
return desc
3169-
31703041
def wait_for_tuning_job(self, job, poll=5):
31713042
"""Wait for an Amazon SageMaker hyperparameter tuning job to complete.
31723043
@@ -3250,13 +3121,7 @@ def _check_job_status(self, job, desc, status_key_name):
32503121
# If the status is capital case, then convert it to Camel case
32513122
status = _STATUS_CODE_TABLE.get(status, status)
32523123

3253-
if status == "Stopped":
3254-
LOGGER.warning(
3255-
"Job ended with status 'Stopped' rather than 'Completed'. "
3256-
"This could mean the job timed out or stopped early for some other reason: "
3257-
"Consider checking whether it completed as you expect."
3258-
)
3259-
elif status != "Completed":
3124+
if status not in ("Completed", "Stopped"):
32603125
reason = desc.get("FailureReason", "(No reason provided)")
32613126
job_type = status_key_name.replace("JobStatus", " job")
32623127
raise exceptions.UnexpectedStatusException(
@@ -3618,7 +3483,6 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
36183483
last_describe_job_call = time.time()
36193484
last_description = description
36203485
last_debug_rule_statuses = None
3621-
last_profiler_rule_statuses = None
36223486

36233487
while True:
36243488
_flush_log_streams(
@@ -3657,32 +3521,22 @@ def logs_for_job( # noqa: C901 - suppress complexity warning for this method
36573521
debug_rule_statuses = description.get("DebugRuleEvaluationStatuses", {})
36583522
if (
36593523
debug_rule_statuses
3660-
and _rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
3524+
and _debug_rule_statuses_changed(debug_rule_statuses, last_debug_rule_statuses)
36613525
and (log_type in {"All", "Rules"})
36623526
):
3527+
print()
3528+
print("********* Debugger Rule Status *********")
3529+
print("*")
36633530
for status in debug_rule_statuses:
3664-
rule_log = (
3665-
f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
3531+
rule_log = "* {:>18}: {:<18}".format(
3532+
status["RuleConfigurationName"], status["RuleEvaluationStatus"]
36663533
)
36673534
print(rule_log)
3535+
print("*")
3536+
print("*" * 40)
36683537

36693538
last_debug_rule_statuses = debug_rule_statuses
36703539

3671-
# Print prettified logs related to the status of SageMaker Profiler rules.
3672-
profiler_rule_statuses = description.get("ProfilerRuleEvaluationStatuses", {})
3673-
if (
3674-
profiler_rule_statuses
3675-
and _rule_statuses_changed(profiler_rule_statuses, last_profiler_rule_statuses)
3676-
and (log_type in {"All", "Rules"})
3677-
):
3678-
for status in profiler_rule_statuses:
3679-
rule_log = (
3680-
f"{status['RuleConfigurationName']}: {status['RuleEvaluationStatus']}"
3681-
)
3682-
print(rule_log)
3683-
3684-
last_profiler_rule_statuses = profiler_rule_statuses
3685-
36863540
if wait:
36873541
self._check_job_status(job_name, description, "TrainingJobStatus")
36883542
if dot:
@@ -4245,38 +4099,6 @@ def _processing_job_status(sagemaker_client, job_name):
42454099
return desc
42464100

42474101

4248-
def _edge_packaging_job_status(sagemaker_client, job_name):
4249-
"""Process the current status of a packaging job
4250-
4251-
Args:
4252-
sagemaker_client (boto3.client.sagemaker): a sagemaker client
4253-
job_name (str): the name of the job to inspec
4254-
4255-
Returns:
4256-
Dict: the status of the edge packaging job
4257-
"""
4258-
package_status_codes = {
4259-
"Completed": "!",
4260-
"InProgress": ".",
4261-
"Failed": "*",
4262-
"Stopped": "s",
4263-
"Stopping": "_",
4264-
}
4265-
in_progress_statuses = ["InProgress", "Stopping", "Starting"]
4266-
4267-
desc = sagemaker_client.describe_edge_packaging_job(EdgePackagingJobName=job_name)
4268-
status = desc["EdgePackagingJobStatus"]
4269-
4270-
status = _STATUS_CODE_TABLE.get(status, status)
4271-
print(package_status_codes.get(status, "?"), end="")
4272-
sys.stdout.flush()
4273-
4274-
if status in in_progress_statuses:
4275-
return None
4276-
4277-
return desc
4278-
4279-
42804102
def _compilation_job_status(sagemaker_client, job_name):
42814103
"""Placeholder docstring"""
42824104
compile_status_codes = {
@@ -4454,8 +4276,8 @@ def _get_initial_job_state(description, status_key, wait):
44544276
return LogState.TAILING if wait and not job_already_completed else LogState.COMPLETE
44554277

44564278

4457-
def _rule_statuses_changed(current_statuses, last_statuses):
4458-
"""Checks the rule evaluation statuses for SageMaker Debugger and Profiler rules."""
4279+
def _debug_rule_statuses_changed(current_statuses, last_statuses):
4280+
"""Checks the rule evaluation statuses for SageMaker Debugger rules."""
44594281
if not last_statuses:
44604282
return True
44614283

0 commit comments

Comments
 (0)