     print("the following Python libraries are required: GitPython, tabulate.")  # noqa: NP100
     raise e
 
+
 logger = logging.getLogger("compare-llama-bench")
 
 # All llama-bench SQL fields
@@ -122,11 +123,15 @@
 parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
 parser.add_argument("-s", "--show", help=help_s)
 parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
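+# Example: --plot plot.png --plot_x n_depth --plot_log_scale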
+parser.add_argument("--plot", help="generate a performance comparison plot and save it to the specified file (e.g., plot.png)")
+parser.add_argument("--plot_x", help="parameter to use as the x axis for plotting (default: n_depth)", default="n_depth")
+parser.add_argument("--plot_log_scale", action="store_true", help="use a log scale for the x axis in plots (off by default)")
 
 known_args, unknown_args = parser.parse_known_args()
 
 logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
 
+
 if known_args.check:
     # Check if all required Python libraries are installed. Would have failed earlier if not.
     sys.exit(0)
@@ -499,7 +504,6 @@ def valid_format(data_files: list[str]) -> bool:
 
 name_compare = bench_data.get_commit_name(hexsha8_compare)
 
-
 # If the user provided columns to group the results by, use them:
 if known_args.show is not None:
     show = known_args.show.split(",")
@@ -544,6 +548,14 @@ def valid_format(data_files: list[str]) -> bool:
             show.remove(prop)
         except ValueError:
             pass
+
+    # Add the plot_x parameter to the shown columns if it is not already present:
+    if known_args.plot:
+        for k, v in PRETTY_NAMES.items():
+            if v == known_args.plot_x and k not in show:
+                show.append(k)
+                break
+
 rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare)
 
 if not rows_show:
@@ -600,6 +612,161 @@ def valid_format(data_files: list[str]) -> bool:
 headers = [PRETTY_NAMES[p] for p in show]
 headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
 
+if known_args.plot:
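+    # Plot baseline vs. compared t/s over the chosen x-axis parameter, one subplot per group of otherwise identical parameters, and save the figure to output_file.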
+    def create_performance_plot(table_data: list[list[str]], headers: list[str], baseline_name: str, compare_name: str, output_file: str, plot_x_param: str, log_scale: bool = False):
+        try:
+            import matplotlib
+            matplotlib.use('Agg')  # select the non-interactive Agg backend before pyplot is imported
+            import matplotlib.pyplot as plt
+        except ImportError as e:
+            logger.error("matplotlib is required for --plot.")
+            raise e
+
+        data_headers = headers[:-4]  # Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
+        plot_x_index = None
+        plot_x_label = plot_x_param
+
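+        # For an x-axis parameter other than n_prompt/n_gen/n_depth, locate its column in the table headers.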
+        if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
+            pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param)
+            if pretty_name in data_headers:
+                plot_x_index = data_headers.index(pretty_name)
+                plot_x_label = pretty_name
+            elif plot_x_param in data_headers:
+                plot_x_index = data_headers.index(plot_x_param)
+                plot_x_label = plot_x_param
+            else:
+                logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
+                return
+
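+        # Group rows that share every parameter except the x-axis value; each group becomes one subplot.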
+        grouped_data = {}
+
+        for row in table_data:
+            group_key_parts = []
+            test_name = row[-4]
+
+            base_test = ""
+            x_value = None
+
+            if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
+                for j, val in enumerate(row[:-4]):
+                    header_name = data_headers[j]
+                    if val is not None and str(val).strip():
+                        group_key_parts.append(f"{header_name}={val}")
+
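+                # Built-in x parameters are encoded in the test name, e.g. "pp512", "tg128", or "pp512@d4096" (the "@d" suffix is the depth).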
+                if plot_x_param == "n_prompt" and "pp" in test_name:
+                    base_test = test_name.split("@")[0]
+                    x_value = base_test
+                elif plot_x_param == "n_gen" and "tg" in test_name:
+                    x_value = test_name.split("@")[0]
+                elif plot_x_param == "n_depth" and "@d" in test_name:
+                    base_test = test_name.split("@d")[0]
+                    x_value = int(test_name.split("@d")[1])
+                else:
+                    base_test = test_name
+
+                if base_test.strip():
+                    group_key_parts.append(f"Test={base_test}")
+            else:
+                for j, val in enumerate(row[:-4]):
+                    if j != plot_x_index:
+                        header_name = data_headers[j]
+                        if val is not None and str(val).strip():
+                            group_key_parts.append(f"{header_name}={val}")
+                    else:
+                        x_value = val
+
+                group_key_parts.append(f"Test={test_name}")
+
+            group_key = tuple(group_key_parts)
+
+            if group_key not in grouped_data:
+                grouped_data[group_key] = []
+
+            grouped_data[group_key].append({
+                'x_value': x_value,
+                'baseline': float(row[-3]),
+                'compare': float(row[-2]),
+                'speedup': float(row[-1])
+            })
+
+        if not grouped_data:
+            logger.error("No data available for plotting")
+            return
+
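+        # Lay the subplots out in a grid of at most max_cols columns, scaling the figure size with the grid.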
+        def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
+            from math import ceil
+            cols = 1 if num_groups == 1 else min(max_cols, num_groups)
+            rows = ceil(num_groups / cols)
+
+            # Scale figure size by grid dimensions
+            w, h = base_size
+            fig, ax_arr = plt.subplots(rows, cols,
+                                       figsize=(w * cols, h * rows),
+                                       squeeze=False)
+
+            axes = ax_arr.flatten()[:num_groups]
+            return fig, axes
+
+        num_groups = len(grouped_data)
+        fig, axes = make_axes(num_groups)
+
+        plot_idx = 0
+
+        for group_key, points in grouped_data.items():
+            if plot_idx >= len(axes):
+                break
+            ax = axes[plot_idx]
+
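+            # Sort the points numerically by x value when possible; otherwise keep them in their original order.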
+            try:
+                points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
+                x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
+            except ValueError:
+                points_sorted = list(points)
+                x_values = [p['x_value'] for p in points_sorted]
+
+            baseline_vals = [p['baseline'] for p in points_sorted]
+            compare_vals = [p['compare'] for p in points_sorted]
+
+            ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
+                    label=f'{baseline_name}', linewidth=2, markersize=6)
+            ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
+                    label=f'{compare_name}', linewidth=2, markersize=6)
+
+            if log_scale:
+                ax.set_xscale('log', base=2)
+                unique_x = sorted(set(x_values))
+                ax.set_xticks(unique_x)
+                ax.set_xticklabels([str(int(x)) for x in unique_x])
+
+            title_parts = []
+            for part in group_key:
+                if '=' in part:
+                    key, value = part.split('=', 1)
+                    title_parts.append(f"{key}: {value}")
+
+            title = ', '.join(title_parts) if title_parts else "Performance comparison"
+
+            ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
+            ax.set_ylabel('Tokens per second (t/s)', fontsize=12, fontweight='bold')
+            ax.set_title(title, fontsize=12, fontweight='bold')
+            ax.legend(loc='best', fontsize=10)
+            ax.grid(True, alpha=0.3)
+
+            plot_idx += 1
+
+        for i in range(plot_idx, len(axes)):
+            axes[i].set_visible(False)
+
+        fig.suptitle(f'Performance comparison: {compare_name} vs. {baseline_name}',
+                     fontsize=14, fontweight='bold')
+        fig.subplots_adjust(top=1)
+
+        plt.tight_layout()
+        plt.savefig(output_file, dpi=300, bbox_inches='tight')
+        plt.close()
+
+    create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x, known_args.plot_log_scale)
+
 print(tabulate(  # noqa: NP100
     table,
     headers=headers,