129
129
130
130
logging .basicConfig (level = logging .DEBUG if known_args .verbose else logging .INFO )
131
131
132
- # Check for matplotlib if plotting is requested
133
132
if known_args .plot :
134
133
try :
135
134
import matplotlib .pyplot as plt
@@ -511,7 +510,6 @@ def valid_format(data_files: list[str]) -> bool:
511
510
512
511
name_compare = bench_data .get_commit_name (hexsha8_compare )
513
512
514
-
515
513
# If the user provided columns to group the results by, use them:
516
514
if known_args .show is not None :
517
515
show = known_args .show .split ("," )
@@ -556,6 +554,14 @@ def valid_format(data_files: list[str]) -> bool:
556
554
show .remove (prop )
557
555
except ValueError :
558
556
pass
557
+
558
+ # add plot_x parameter to if it's not already there
559
+ if known_args .plot :
560
+ for k , v in PRETTY_NAMES .items ():
561
+ if v == known_args .plot_x and k not in show :
562
+ show .append (k )
563
+ break
564
+
559
565
rows_show = bench_data .get_rows (show , hexsha8_baseline , hexsha8_compare )
560
566
561
567
if not rows_show :
@@ -629,7 +635,6 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
629
635
plot_x_label = plot_x_param
630
636
else :
631
637
logger .error (f"Parameter '{ plot_x_param } ' not found in current table columns. Available columns: { ', ' .join (data_headers )} " )
632
- logger .error (f"To plot by '{ plot_x_param } ', include it in --show parameter or ensure it varies in your data." )
633
638
return
634
639
635
640
grouped_data = {}
@@ -671,7 +676,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
671
676
672
677
group_key_parts .append (f"Test={ test_name } " )
673
678
674
- group_key = tuple (sorted ( group_key_parts ) )
679
+ group_key = tuple (group_key_parts )
675
680
676
681
if group_key not in grouped_data :
677
682
grouped_data [group_key ] = []
@@ -692,7 +697,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
692
697
cols = 1 if num_groups == 1 else min (max_cols , num_groups )
693
698
rows = ceil (num_groups / cols )
694
699
695
- # scale figure size by grid dimensions
700
+ # Scale figure size by grid dimensions
696
701
w , h = base_size
697
702
fig , ax_arr = plt .subplots (rows , cols ,
698
703
figsize = (w * cols , h * rows ),
@@ -726,7 +731,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
726
731
ax .plot (x_values , compare_vals , 's--' , color = 'lightcoral' , alpha = 0.8 ,
727
732
label = f'{ compare_name } ' , linewidth = 2 , markersize = 6 )
728
733
729
- if plot_x_param == "n_depth" and max (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
734
+ if plot_x_param == "n_depth" and min (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
730
735
ax .set_xscale ('log' , base = 2 )
731
736
unique_x = sorted (set (x_values ))
732
737
ax .set_xticks (unique_x )
@@ -741,7 +746,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
741
746
title = ', ' .join (title_parts ) if title_parts else "Performance comparison"
742
747
743
748
ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
744
- ax .set_ylabel ('Tokens per Second (t/s)' , fontsize = 12 , fontweight = 'bold' )
749
+ ax .set_ylabel ('Tokens per second (t/s)' , fontsize = 12 , fontweight = 'bold' )
745
750
ax .set_title (title , fontsize = 12 , fontweight = 'bold' )
746
751
ax .legend (loc = 'best' , fontsize = 10 )
747
752
ax .grid (True , alpha = 0.3 )
@@ -751,7 +756,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
751
756
for i in range (plot_idx , len (axes )):
752
757
axes [i ].set_visible (False )
753
758
754
- fig .suptitle (f'Performance comparison: { compare_name } vs { baseline_name } ' ,
759
+ fig .suptitle (f'Performance comparison: { compare_name } vs. { baseline_name } ' ,
755
760
fontsize = 14 , fontweight = 'bold' )
756
761
fig .subplots_adjust (top = 1 )
757
762
0 commit comments