123
123
parser .add_argument ("-s" , "--show" , help = help_s )
124
124
parser .add_argument ("--verbose" , action = "store_true" , help = "increase output verbosity" )
125
125
parser .add_argument ("--plot" , help = "generate a performance comparison plot and save to specified file (e.g., plot.png)" )
126
- parser .add_argument ("--plot_x" , help = "parameter to use as x- axis for plotting (default: n_depth)" , default = "n_depth" )
126
+ parser .add_argument ("--plot_x" , help = "parameter to use as x axis for plotting (default: n_depth)" , default = "n_depth" )
127
127
128
128
known_args , unknown_args = parser .parse_known_args ()
129
129
136
136
import matplotlib
137
137
matplotlib .use ('Agg' )
138
138
except ImportError as e :
139
- print ("matplotlib is required for --plot." )
139
+ logger . error ("matplotlib is required for --plot." )
140
140
raise e
141
141
142
142
if known_args .check :
@@ -613,9 +613,9 @@ def valid_format(data_files: list[str]) -> bool:
613
613
headers += ["Test" , f"t/s { name_baseline } " , f"t/s { name_compare } " , "Speedup" ]
614
614
615
615
if known_args .plot :
616
- def create_performance_plot (table_data , headers , baseline_name , compare_name , output_file , plot_x_param ):
616
+ def create_performance_plot (table_data : list [ list [ str ]] , headers : list [ str ] , baseline_name : str , compare_name : str , output_file : str , plot_x_param : str ):
617
617
618
- data_headers = headers [:- 4 ] #Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
618
+ data_headers = headers [:- 4 ] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
619
619
plot_x_index = None
620
620
plot_x_label = plot_x_param
621
621
@@ -687,7 +687,6 @@ def create_performance_plot(table_data, headers, baseline_name, compare_name, ou
687
687
logger .error ("No data available for plotting" )
688
688
return
689
689
690
-
691
690
def make_axes (num_groups , max_cols = 2 , base_size = (8 , 4 )):
692
691
from math import ceil
693
692
cols = 1 if num_groups == 1 else min (max_cols , num_groups )
@@ -696,8 +695,8 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
696
695
# scale figure size by grid dimensions
697
696
w , h = base_size
698
697
fig , ax_arr = plt .subplots (rows , cols ,
699
- figsize = (w * cols , h * rows ),
700
- squeeze = False )
698
+ figsize = (w * cols , h * rows ),
699
+ squeeze = False )
701
700
702
701
axes = ax_arr .flatten ()[:num_groups ]
703
702
return fig , axes
@@ -739,7 +738,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
739
738
key , value = part .split ('=' , 1 )
740
739
title_parts .append (f"{ key } : { value } " )
741
740
742
- title = ', ' .join (title_parts ) if title_parts else "Performance Comparison "
741
+ title = ', ' .join (title_parts ) if title_parts else "Performance comparison "
743
742
744
743
ax .set_xlabel (plot_x_label , fontsize = 12 , fontweight = 'bold' )
745
744
ax .set_ylabel ('Tokens per Second (t/s)' , fontsize = 12 , fontweight = 'bold' )
@@ -752,11 +751,10 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
752
751
for i in range (plot_idx , len (axes )):
753
752
axes [i ].set_visible (False )
754
753
755
- fig .suptitle (f'Performance Comparison : { compare_name } vs { baseline_name } ' ,
756
- fontsize = 14 , fontweight = 'bold' )
754
+ fig .suptitle (f'Performance comparison : { compare_name } vs { baseline_name } ' ,
755
+ fontsize = 14 , fontweight = 'bold' )
757
756
fig .subplots_adjust (top = 1 )
758
757
759
-
760
758
plt .tight_layout ()
761
759
plt .savefig (output_file , dpi = 300 , bbox_inches = 'tight' )
762
760
plt .close ()
0 commit comments