@@ -511,7 +511,6 @@ def valid_format(data_files: list[str]) -> bool:
511
511
512
512
name_compare = bench_data .get_commit_name (hexsha8_compare )
513
513
514
-
515
514
# If the user provided columns to group the results by, use them:
516
515
if known_args .show is not None :
517
516
show = known_args .show .split ("," )
@@ -556,6 +555,14 @@ def valid_format(data_files: list[str]) -> bool:
556
555
show .remove (prop )
557
556
except ValueError :
558
557
pass
558
+
559
+ if known_args .plot :
560
+ for k , v in PRETTY_NAMES .items ():
561
+ if v == known_args .plot_x and k not in show :
562
+ logger .info (f"Adding { k } to --show" )
563
+ show .append (k )
564
+ break
565
+
559
566
rows_show = bench_data .get_rows (show , hexsha8_baseline , hexsha8_compare )
560
567
561
568
if not rows_show :
@@ -629,7 +636,6 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
629
636
plot_x_label = plot_x_param
630
637
else :
631
638
logger .error (f"Parameter '{ plot_x_param } ' not found in current table columns. Available columns: { ', ' .join (data_headers )} " )
632
- logger .error (f"To plot by '{ plot_x_param } ', include it in --show parameter or ensure it varies in your data." )
633
639
return
634
640
635
641
grouped_data = {}
@@ -671,7 +677,7 @@ def create_performance_plot(table_data: list[list[str]], headers: list[str], bas
671
677
672
678
group_key_parts .append (f"Test={ test_name } " )
673
679
674
- group_key = tuple (sorted ( group_key_parts ) )
680
+ group_key = tuple (group_key_parts )
675
681
676
682
if group_key not in grouped_data :
677
683
grouped_data [group_key ] = []
@@ -692,7 +698,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
692
698
cols = 1 if num_groups == 1 else min (max_cols , num_groups )
693
699
rows = ceil (num_groups / cols )
694
700
695
- # scale figure size by grid dimensions
701
+ # Scale figure size by grid dimensions
696
702
w , h = base_size
697
703
fig , ax_arr = plt .subplots (rows , cols ,
698
704
figsize = (w * cols , h * rows ),
@@ -726,7 +732,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
726
732
ax .plot (x_values , compare_vals , 's--' , color = 'lightcoral' , alpha = 0.8 ,
727
733
label = f'{ compare_name } ' , linewidth = 2 , markersize = 6 )
728
734
729
- if plot_x_param == "n_depth" and max (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
735
+ if plot_x_param == "n_depth" and min (x_values ) > 0 and max (x_values ) > min (x_values ) * 4 :
730
736
ax .set_xscale ('log' , base = 2 )
731
737
unique_x = sorted (set (x_values ))
732
738
ax .set_xticks (unique_x )
@@ -741,7 +747,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
741
747
title = ', ' .join (title_parts ) if title_parts else "Performance comparison"
742
748
743
749
ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
744
- ax .set_ylabel ('Tokens per Second (t/s)' , fontsize = 12 , fontweight = 'bold' )
750
+ ax .set_ylabel ('Tokens per second (t/s)' , fontsize = 12 , fontweight = 'bold' )
745
751
ax .set_title (title , fontsize = 12 , fontweight = 'bold' )
746
752
ax .legend (loc = 'best' , fontsize = 10 )
747
753
ax .grid (True , alpha = 0.3 )
@@ -751,7 +757,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
751
757
for i in range (plot_idx , len (axes )):
752
758
axes [i ].set_visible (False )
753
759
754
- fig .suptitle (f'Performance comparison: { compare_name } vs { baseline_name } ' ,
760
+ fig .suptitle (f'Performance comparison: { compare_name } vs. { baseline_name } ' ,
755
761
fontsize = 14 , fontweight = 'bold' )
756
762
fig .subplots_adjust (top = 1 )
757
763
0 commit comments