Skip to content

Commit 5426c87

Browse files
committed
Address review comments: convert case + add type hints
1 parent 1690177 commit 5426c87

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

scripts/compare-llama-bench.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@
123123
parser.add_argument("-s", "--show", help=help_s)
124124
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
125125
parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
126-
parser.add_argument("--plot_x", help="parameter to use as x-axis for plotting (default: n_depth)", default="n_depth")
126+
parser.add_argument("--plot_x", help="parameter to use as x axis for plotting (default: n_depth)", default="n_depth")
127127

128128
known_args, unknown_args = parser.parse_known_args()
129129

@@ -136,7 +136,7 @@
136136
import matplotlib
137137
matplotlib.use('Agg')
138138
except ImportError as e:
139-
print("matplotlib is required for --plot.")
139+
logger.error("matplotlib is required for --plot.")
140140
raise e
141141

142142
if known_args.check:
@@ -613,9 +613,9 @@ def valid_format(data_files: list[str]) -> bool:
613613
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
614614

615615
if known_args.plot:
616-
def create_performance_plot(table_data, headers, baseline_name, compare_name, output_file, plot_x_param):
616+
def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str):
617617

618-
data_headers = headers[:-4] #Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
618+
data_headers = headers[:-4] # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
619619
plot_x_index = None
620620
plot_x_label = plot_x_param
621621

@@ -687,7 +687,6 @@ def create_performance_plot(table_data, headers, baseline_name, compare_name, ou
687687
logger.error("No data available for plotting")
688688
return
689689

690-
691690
def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
692691
from math import ceil
693692
cols = 1 if num_groups == 1 else min(max_cols, num_groups)
@@ -696,8 +695,8 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
696695
# scale figure size by grid dimensions
697696
w, h = base_size
698697
fig, ax_arr = plt.subplots(rows, cols,
699-
figsize=(w * cols, h * rows),
700-
squeeze=False)
698+
figsize=(w * cols, h * rows),
699+
squeeze=False)
701700

702701
axes = ax_arr.flatten()[:num_groups]
703702
return fig, axes
@@ -739,7 +738,7 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
739738
key, value = part.split('=', 1)
740739
title_parts.append(f"{key}: {value}")
741740

742-
title = ', '.join(title_parts) if title_parts else "Performance Comparison"
741+
title = ', '.join(title_parts) if title_parts else "Performance comparison"
743742

744743
ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
745744
ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold')
@@ -752,11 +751,10 @@ def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
752751
for i in range(plot_idx, len(axes)):
753752
axes[i].set_visible(False)
754753

755-
fig.suptitle(f'Performance Comparison: {compare_name} vs {baseline_name}',
756-
fontsize=14, fontweight='bold')
754+
fig.suptitle(f'Performance comparison: {compare_name} vs {baseline_name}',
755+
fontsize=14, fontweight='bold')
757756
fig.subplots_adjust(top=1)
758757

759-
760758
plt.tight_layout()
761759
plt.savefig(output_file, dpi=300, bbox_inches='tight')
762760
plt.close()

0 commit comments

Comments
 (0)