Fix hyperparameter strategy docs #5045

Merged: 1 commit, Feb 18, 2025
46 changes: 20 additions & 26 deletions src/sagemaker/tuner.py
```diff
@@ -18,21 +18,20 @@
 import inspect
 import json
 import logging
-
 from enum import Enum
-from typing import Union, Dict, Optional, List, Set
+from typing import Dict, List, Optional, Set, Union
 
 import sagemaker
 from sagemaker.amazon.amazon_estimator import (
-    RecordSet,
     AmazonAlgorithmEstimatorBase,
     FileSystemRecordSet,
+    RecordSet,
 )
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
 from sagemaker.analytics import HyperparameterTuningJobAnalytics
 from sagemaker.deprecations import removed_function
-from sagemaker.estimator import Framework, EstimatorBase
-from sagemaker.inputs import TrainingInput, FileSystemInput
+from sagemaker.estimator import EstimatorBase, Framework
+from sagemaker.inputs import FileSystemInput, TrainingInput
 from sagemaker.job import _Job
 from sagemaker.jumpstart.utils import (
     add_jumpstart_uri_tags,
@@ -44,18 +43,17 @@
     IntegerParameter,
     ParameterRange,
 )
-from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
-
 from sagemaker.session import Session
 from sagemaker.utils import (
+    Tags,
     base_from_name,
     base_name_from_image,
+    format_tags,
     name_from_base,
     to_string,
-    format_tags,
-    Tags,
 )
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline
 
 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {
```
```diff
@@ -133,15 +131,12 @@ def __init__(
 
         if warm_start_type not in list(WarmStartTypes):
             raise ValueError(
-                "Invalid type: {}, valid warm start types are: {}".format(
-                    warm_start_type, list(WarmStartTypes)
-                )
+                f"Invalid type: {warm_start_type}, "
+                f"valid warm start types are: {list(WarmStartTypes)}"
             )
 
         if not parents:
-            raise ValueError(
-                "Invalid parents: {}, parents should not be None/empty".format(parents)
-            )
+            raise ValueError(f"Invalid parents: {parents}, parents should not be None/empty")
 
         self.type = warm_start_type
         self.parents = set(parents)
```
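These messages surface when a warm start configuration is constructed. A minimal sketch of the valid and invalid paths, assuming the `sagemaker` package is installed; the parent tuning job name is a hypothetical placeholder:

```python
from sagemaker.tuner import WarmStartConfig, WarmStartTypes

# Valid: a recognized warm start type plus at least one parent tuning job.
# "parent-tuning-job-1" is a hypothetical job name used for illustration.
config = WarmStartConfig(
    warm_start_type=WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM,
    parents={"parent-tuning-job-1"},
)

# Invalid: empty parents triggers the ValueError shown in the diff above.
try:
    WarmStartConfig(WarmStartTypes.IDENTICAL_DATA_AND_ALGORITHM, parents=None)
except ValueError as err:
    print(err)  # "Invalid parents: None, parents should not be None/empty"
```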
```diff
@@ -1455,9 +1450,7 @@ def _get_best_training_job(self):
             return tuning_job_describe_result["BestTrainingJob"]
         except KeyError:
             raise Exception(
-                "Best training job not available for tuning job: {}".format(
-                    self.latest_tuning_job.name
-                )
+                f"Best training job not available for tuning job: {self.latest_tuning_job.name}"
             )
 
     def _ensure_last_tuning_job(self):
```
```diff
@@ -1920,8 +1913,11 @@ def create(
                 :meth:`~sagemaker.tuner.HyperparameterTuner.fit` method launches.
                 If not specified, a default job name is generated,
                 based on the training image name and current timestamp.
-            strategy (str): Strategy to be used for hyperparameter estimations
-                (default: 'Bayesian').
+            strategy (str or PipelineVariable): Strategy to be used for hyperparameter estimations.
+                More information about different strategies:
+                https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-how-it-works.html.
+                Available options are: 'Bayesian', 'Random', 'Hyperband',
+                'Grid' (default: 'Bayesian')
             strategy_config (dict): The configuration for a training job launched by a
                 hyperparameter tuning job.
             completion_criteria_config (dict): The configuration for tuning job completion criteria.
```
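This docstring fix is the substance of the PR: it documents all four supported strategies rather than implying 'Bayesian' is the only option. A minimal sketch of selecting a non-default strategy, assuming `estimator` is a previously configured SageMaker Estimator; the metric name and hyperparameter range are illustrative placeholders:

```python
from sagemaker.parameter import ContinuousParameter
from sagemaker.tuner import HyperparameterTuner

# `estimator` is assumed to have been created elsewhere (role, image, etc.).
tuner = HyperparameterTuner(
    estimator=estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges={"learning_rate": ContinuousParameter(0.001, 0.1)},
    strategy="Hyperband",  # any of 'Bayesian', 'Random', 'Hyperband', 'Grid'
)
```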
```diff
@@ -2080,21 +2076,19 @@ def _validate_dict_argument(cls, name, value, allowed_keys, require_same_keys=False):
             return
 
         if not isinstance(value, dict):
-            raise ValueError(
-                "Argument '{}' must be a dictionary using {} as keys".format(name, allowed_keys)
-            )
+            raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys")
 
         value_keys = sorted(value.keys())
 
         if require_same_keys:
             if value_keys != allowed_keys:
                 raise ValueError(
-                    "The keys of argument '{}' must be the same as {}".format(name, allowed_keys)
+                    f"The keys of argument '{name}' must be the same as {allowed_keys}"
                 )
         else:
             if not set(value_keys).issubset(set(allowed_keys)):
                 raise ValueError(
-                    "The keys of argument '{}' must be a subset of {}".format(name, allowed_keys)
+                    f"The keys of argument '{name}' must be a subset of {allowed_keys}"
                 )
 
     def _add_estimator(
```
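`_validate_dict_argument` is private, but the subset-versus-same-keys distinction it enforces is easy to miss. A hypothetical standalone sketch of the same pattern, mirroring the f-string messages introduced above, with a usage example of each outcome:

```python
def validate_dict_argument(name, value, allowed_keys, require_same_keys=False):
    """Hypothetical standalone version of the validation shown in the diff."""
    if value is None:
        return

    if not isinstance(value, dict):
        raise ValueError(f"Argument '{name}' must be a dictionary using {allowed_keys} as keys")

    value_keys = sorted(value.keys())

    if require_same_keys:
        # Exactly the allowed keys must be present, nothing more, nothing less.
        if value_keys != allowed_keys:
            raise ValueError(f"The keys of argument '{name}' must be the same as {allowed_keys}")
    elif not set(value_keys).issubset(set(allowed_keys)):
        # Extra keys are rejected; missing keys are fine in subset mode.
        raise ValueError(f"The keys of argument '{name}' must be a subset of {allowed_keys}")


# Passes: {"train"} is a subset of the allowed keys.
validate_dict_argument("inputs", {"train": "s3://bucket/train"}, ["test", "train"])

# Raises ValueError: "validate" is not an allowed key.
validate_dict_argument("inputs", {"validate": "s3://bucket/val"}, ["test", "train"])
```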