|
8 | 8 | # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
|
9 | 9 | # ANY KIND, either express or implied. See the License for the specific
|
10 | 10 | # language governing permissions and limitations under the License.
|
11 |
| -""" |
12 |
| -This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. |
| 11 | +"""This module provides visualization capabilities for SageMaker hyperparameter tuning jobs. |
13 | 12 |
|
14 | 13 | It contains utilities to create interactive visualizations of hyperparameter tuning results
|
15 | 14 | using Altair charts. The module enables users to analyze and understand the performance
|
@@ -83,8 +82,7 @@ def visualize_tuning_job(
|
83 | 82 | trials_only: bool = False,
|
84 | 83 | advanced: bool = False,
|
85 | 84 | ) -> Union[alt.Chart, Tuple[alt.Chart, pd.DataFrame, pd.DataFrame]]:
|
86 |
| - """ |
87 |
| - Visualize SageMaker hyperparameter tuning jobs. |
| 85 | + """Visualize SageMaker hyperparameter tuning jobs. |
88 | 86 |
|
89 | 87 | Args:
|
90 | 88 | tuning_jobs: Single tuning job or list of tuning jobs (name or HyperparameterTuner object)
|
@@ -147,8 +145,7 @@ def create_charts(
|
147 | 145 | color_trials: bool = False,
|
148 | 146 | advanced: bool = False,
|
149 | 147 | ) -> alt.Chart:
|
150 |
| - """ |
151 |
| - Create visualization charts for hyperparameter tuning results. |
| 148 | + """Create visualization charts for hyperparameter tuning results. |
152 | 149 |
|
153 | 150 | Args:
|
154 | 151 | trials_df: DataFrame containing trials data
|
@@ -240,7 +237,8 @@ def create_charts(
|
240 | 237 | # If we have multiple tuning jobs, we also want to be able
|
241 | 238 | # to discriminate based on the individual tuning job, so
|
242 | 239 | # we just treat them as an additional tuning parameter
|
243 |
| - tuning_parameters = tuning_parameters.copy() + (["TuningJobName"] if multiple_tuning_jobs else []) |
| 240 | + tuning_job_param = ["TuningJobName"] if multiple_tuning_jobs else [] |
| 241 | + tuning_parameters = tuning_parameters.copy() + tuning_job_param |
244 | 242 |
|
245 | 243 | # If we use early stopping and at least some jobs were
|
246 | 244 | # stopped early, we want to be able to discriminate
|
@@ -331,7 +329,7 @@ def render_detail_charts():
|
331 | 329 | bandwidth=0.01,
|
332 | 330 | groupby=[tuning_parameter],
|
333 | 331 | # https://github.com/vega/altair/issues/3203#issuecomment-2141558911
|
334 |
| - # Specifying extent no longer necessary (>5.1.2). Leaving the work around in it for now. |
| 332 | + # Specifying extent no longer necessary (>5.1.2). |
335 | 333 | extent=[
|
336 | 334 | trials_df[objective_name].min(),
|
337 | 335 | trials_df[objective_name].max(),
|
@@ -612,7 +610,7 @@ def render_progress_chart():
|
612 | 610 |
|
613 | 611 |
|
614 | 612 | def _clean_parameter_name(s):
|
615 |
| - """ Helper method to ensure proper parameter name characters for altair 5+ """ |
| 613 | + """Helper method to ensure proper parameter name characters for altair 5+""" |
616 | 614 | return s.replace(":", "_").replace(".", "_")
|
617 | 615 |
|
618 | 616 |
|
@@ -664,8 +662,10 @@ def _prepare_consolidated_df(trials_df):
|
664 | 662 |
|
665 | 663 |
|
666 | 664 | def _get_df(tuning_job_name, filter_out_stopped=False):
|
667 |
| - """Retrieves hyperparameter tuning job results and returns preprocessed DataFrame with |
668 |
| - tuning metrics and parameters.""" |
| 665 | + """Retrieves hyperparameter tuning job results and returns preprocessed DataFrame. |
| 666 | +
|
| 667 | + Returns a DataFrame containing tuning metrics and parameters for the specified job. |
| 668 | + """ |
669 | 669 |
|
670 | 670 | tuner = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name)
|
671 | 671 |
|
@@ -707,10 +707,12 @@ def _get_df(tuning_job_name, filter_out_stopped=False):
|
707 | 707 | # A float then?
|
708 | 708 | df[parameter_name] = df[parameter_name].astype(float)
|
709 | 709 |
|
710 |
| - except Exception: |
711 |
| - # Trouble, as this was not a number just pretending to be a string, but an actual string with |
712 |
| - # characters. Leaving the value untouched |
713 |
| - # Ex: Caught exception could not convert string to float: 'sqrt' <class 'ValueError'> |
| 710 | + except (ValueError, TypeError, AttributeError): |
| 711 | + # Catch exceptions that might occur during string manipulation or type conversion |
| 712 | + # - ValueError: Could not convert string to float/int |
| 713 | + # - TypeError: Object doesn't support the operation |
| 714 | + # - AttributeError: Object doesn't have replace method |
| 715 | + # Leaving the value untouched |
714 | 716 | pass
|
715 | 717 |
|
716 | 718 | return df
|
@@ -747,7 +749,7 @@ def get_job_analytics_data(tuning_job_names):
|
747 | 749 | tuning_job_names (str or list): Single tuning job name or list of names/tuner objects.
|
748 | 750 |
|
749 | 751 | Returns:
|
750 |
| - tuple: (DataFrame with training results, tuned parameters list, objective name, is_minimize flag). |
| 752 | + tuple: (DataFrame with training results, tuned params list, objective name, is_minimize). |
751 | 753 |
|
752 | 754 | Raises:
|
753 | 755 | ValueError: If tuning jobs have different objectives or optimization directions.
|
@@ -828,16 +830,18 @@ def get_job_analytics_data(tuning_job_names):
|
828 | 830 | if isinstance(val, str) and val.startswith('"'):
|
829 | 831 | try:
|
830 | 832 | df[column_name] = df[column_name].apply(lambda x: int(x.replace('"', "")))
|
831 |
| - except: # noqa: E722 nosec b110 if we fail, we just continue with what we had |
| 833 | + except (ValueError, TypeError, AttributeError): |
| 834 | + # noqa: E722 nosec b110 if we fail, we just continue with what we had |
832 | 835 | pass # Value is not an int, but a string
|
833 | 836 |
|
834 | 837 | df = df.sort_values("FinalObjectiveValue", ascending=is_minimize)
|
835 | 838 | df[objective_name] = df.pop("FinalObjectiveValue")
|
836 | 839 |
|
837 | 840 | # Fix potential issue with dates represented as objects, instead of a timestamp
|
838 | 841 | # This can in other cases lead to:
|
839 |
| - # https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type-date-not-json-serializable/ |
840 |
| - # Have only observed this for TrainingEndTime, but will be on the lookout dfor TrainingStartTime as well now |
| 842 | + # https://www.markhneedham.com/blog/2020/01/10/altair-typeerror-object-type- |
| 843 | + # date-not-json-serializable/ |
| 844 | + # Seen this for TrainingEndTime, but will watch TrainingStartTime as well now. |
841 | 845 | df["TrainingEndTime"] = pd.to_datetime(df["TrainingEndTime"])
|
842 | 846 | df["TrainingStartTime"] = pd.to_datetime(df["TrainingStartTime"])
|
843 | 847 |
|
|
0 commit comments