Skip to content

Commit 75631ef

Browse files
author
Uemit Yoldas
committed
fix: fix docstyle and flake8 errors
1 parent 103d5b6 commit 75631ef

File tree

3 files changed

+30
-25
lines changed

3 files changed

+30
-25
lines changed

src/sagemaker/amtviz/job_metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _get_metric_data(
104104
start_time: datetime,
105105
end_time: datetime
106106
) -> pd.DataFrame:
107-
"""Fetches CloudWatch metrics between timestamps and returns a DataFrame with selected columns."""
107+
"""Fetches CloudWatch metrics between timestamps, returns a DataFrame with selected columns."""
108108
start_time = start_time - timedelta(hours=1)
109109
end_time = end_time + timedelta(hours=1)
110110
response = cw.get_metric_data(MetricDataQueries=queries, StartTime=start_time, EndTime=end_time)
@@ -135,7 +135,7 @@ def _collect_metrics(
135135
start_time: datetime,
136136
end_time: Optional[datetime]
137137
) -> pd.DataFrame:
138-
"""Collects SageMaker training job metrics from CloudWatch based on given dimensions and time range."""
138+
"""Collects SageMaker training job metrics from CloudWatch for dimensions and time range."""
139139
df = pd.DataFrame()
140140
for dim_name, dim_value in dimensions:
141141
response = cw.list_metrics(

src/sagemaker/amtviz/visualization.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
99
# ANY KIND, either express or implied. See the License for the specific
1010
# language governing permissions and limitations under the License.
11-
"""
12-
This module provides visualization capabilities for SageMaker hyperparameter tuning jobs.
11+
"""This module provides visualization capabilities for SageMaker hyperparameter tuning jobs.
1312
1413
It contains utilities to create interactive visualizations of hyperparameter tuning results
1514
using Altair charts. The module enables users to analyze and understand the performance
@@ -83,8 +82,7 @@ def visualize_tuning_job(
8382
trials_only: bool = False,
8483
advanced: bool = False,
8584
) -> Union[alt.Chart, Tuple[alt.Chart, pd.DataFrame, pd.DataFrame]]:
86-
"""
87-
Visualize SageMaker hyperparameter tuning jobs.
85+
"""Visualize SageMaker hyperparameter tuning jobs.
8886
8987
Args:
9088
tuning_jobs: Single tuning job or list of tuning jobs (name or HyperparameterTuner object)
@@ -147,8 +145,7 @@ def create_charts(
147145
color_trials: bool = False,
148146
advanced: bool = False,
149147
) -> alt.Chart:
150-
"""
151-
Create visualization charts for hyperparameter tuning results.
148+
"""Create visualization charts for hyperparameter tuning results.
152149
153150
Args:
154151
trials_df: DataFrame containing trials data
@@ -240,7 +237,8 @@ def create_charts(
240237
# If we have multiple tuning jobs, we also want to be able
241238
# to discriminate based on the individual tuning job, so
242239
# we just treat them as an additional tuning parameter
243-
tuning_parameters = tuning_parameters.copy() + (["TuningJobName"] if multiple_tuning_jobs else [])
240+
tuning_job_param = ["TuningJobName"] if multiple_tuning_jobs else []
241+
tuning_parameters = tuning_parameters.copy() + tuning_job_param
244242

245243
# If we use early stopping and at least some jobs were
246244
# stopped early, we want to be able to discriminate
@@ -331,7 +329,7 @@ def render_detail_charts():
331329
bandwidth=0.01,
332330
groupby=[tuning_parameter],
333331
# https://github.com/vega/altair/issues/3203#issuecomment-2141558911
334-
# Specifying extent no longer necessary (>5.1.2). Leaving the work around in it for now.
332+
# Specifying extent is no longer necessary (>5.1.2); workaround kept for now.
335333
extent=[
336334
trials_df[objective_name].min(),
337335
trials_df[objective_name].max(),
@@ -612,7 +610,7 @@ def render_progress_chart():
612610

613611

614612
def _clean_parameter_name(s):
615-
""" Helper method to ensure proper parameter name characters for altair 5+ """
613+
"""Helper method to ensure proper parameter name characters for altair 5+"""
616614
return s.replace(":", "_").replace(".", "_")
617615

618616

@@ -664,8 +662,10 @@ def _prepare_consolidated_df(trials_df):
664662

665663

666664
def _get_df(tuning_job_name, filter_out_stopped=False):
667-
"""Retrieves hyperparameter tuning job results and returns preprocessed DataFrame with
668-
tuning metrics and parameters."""
665+
"""Retrieves hyperparameter tuning job results and returns preprocessed DataFrame.
666+
667+
Returns a DataFrame containing tuning metrics and parameters for the specified job.
668+
"""
669669

670670
tuner = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name)
671671

@@ -707,10 +707,12 @@ def _get_df(tuning_job_name, filter_out_stopped=False):
707707
# A float then?
708708
df[parameter_name] = df[parameter_name].astype(float)
709709

710-
except Exception:
711-
# Trouble, as this was not a number just pretending to be a string, but an actual string with
712-
# characters. Leaving the value untouched
713-
# Ex: Caught exception could not convert string to float: 'sqrt' <class 'ValueError'>
710+
except (ValueError, TypeError, AttributeError):
711+
# Catch exceptions that might occur during string manipulation or type conversion
712+
# - ValueError: Could not convert string to float/int
713+
# - TypeError: Object doesn't support the operation
714+
# - AttributeError: Object doesn't have replace method
715+
# Leaving the value untouched
714716
pass
715717

716718
return df
@@ -747,7 +749,7 @@ def get_job_analytics_data(tuning_job_names):
747749
tuning_job_names (str or list): Single tuning job name or list of names/tuner objects.
748750
749751
Returns:
750-
tuple: (DataFrame with training results, tuned parameters list, objective name, is_minimize flag).
752+
tuple: (DataFrame with training results, tuned params list, objective name, is_minimize).
751753
752754
Raises:
753755
ValueError: If tuning jobs have different objectives or optimization directions.
@@ -828,16 +830,18 @@ def get_job_analytics_data(tuning_job_names):
828830
if isinstance(val, str) and val.startswith('"'):
829831
try:
830832
df[column_name] = df[column_name].apply(lambda x: int(x.replace('"', "")))
831-
except: # noqa: E722 nosec b110 if we fail, we just continue with what we had
833+
except (ValueError, TypeError, AttributeError):
834+
# If we fail, we just continue with what we had
832835
pass # Value is not an int, but a string
833836

834837
df = df.sort_values("FinalObjectiveValue", ascending=is_minimize)
835838
df[objective_name] = df.pop("FinalObjectiveValue")
836839

837840
# Fix potential issue with dates represented as objects, instead of a timestamp
838841
# This can in other cases lead to:
839-
# https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type-date-not-json-serializable/
840-
# Have only observed this for TrainingEndTime, but will be on the lookout dfor TrainingStartTime as well now
842+
# https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type-
843+
# date-not-json-serializable/
844+
# Seen this for TrainingEndTime, but will watch TrainingStartTime as well now.
841845
df["TrainingEndTime"] = pd.to_datetime(df["TrainingEndTime"])
842846
df["TrainingStartTime"] = pd.to_datetime(df["TrainingStartTime"])
843847

src/sagemaker/tuner.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,8 +2128,8 @@ def visualize_jobs(
21282128
trials_only: bool = False,
21292129
advanced: bool = False
21302130
):
2131-
"""Create an interactive visualization based on altair charts using the sagemaker.amtviz
2132-
package.
2131+
"""Create interactive visualization via altair charts using the sagemaker.amtviz package.
2132+
21332133
Args:
21342134
tuning_jobs (str or sagemaker.tuner.HyperparameterTuner or list[str, sagemaker.tuner.HyperparameterTuner]):
21352135
One or more tuning jobs to create
@@ -2147,9 +2147,9 @@ def visualize_jobs(
21472147
importlib.import_module('altair')
21482148

21492149
except ImportError:
2150-
print("Altair is not installed. To use the visualization feature, please install Altair:")
2150+
print("Altair is not installed. Install Altair to use the visualization feature:")
21512151
print(" pip install altair")
2152-
print("After installing Altair, you can use the methods visualize_jobs or visualize_job.")
2152+
print("After installing Altair, use the methods visualize_jobs or visualize_job.")
21532153
return None
21542154

21552155
# If altair is installed, proceed with visualization
@@ -2170,6 +2170,7 @@ def visualize_job(
21702170
advanced: bool = False
21712171
):
21722172
"""Convenience method on instance level for visualize_jobs().
2173+
21732174
See static method visualize_jobs().
21742175
"""
21752176
return HyperparameterTuner.visualize_jobs(

0 commit comments

Comments
 (0)