|
122 | 122 | parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
|
123 | 123 | parser.add_argument("-s", "--show", help=help_s)
|
124 | 124 | parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
| 125 | +parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)") |
| 126 | +parser.add_argument("--plot_x", help="parameter to use as x-axis for plotting (default: n_depth)", default="n_depth") |
125 | 127 |
|
126 | 128 | known_args, unknown_args = parser.parse_known_args()
|
127 | 129 |
|
128 | 130 | logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)
|
129 | 131 |
|
| 132 | +# Check for matplotlib if plotting is requested |
| 133 | +if known_args.plot: |
| 134 | + try: |
| 135 | + import matplotlib.pyplot as plt |
| 136 | + import matplotlib |
| 137 | + matplotlib.use('Agg') |
| 138 | + except ImportError as e: |
| 139 | + print("matplotlib is required for --plot.") |
| 140 | + raise e |
| 141 | + |
130 | 142 | if known_args.check:
|
131 | 143 | # Check if all required Python libraries are installed. Would have failed earlier if not.
|
132 | 144 | sys.exit(0)
|
@@ -600,6 +612,157 @@ def valid_format(data_files: list[str]) -> bool:
|
600 | 612 | headers = [PRETTY_NAMES[p] for p in show]
|
601 | 613 | headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]
|
602 | 614 |
|
| 615 | +if known_args.plot: |
| 616 | + def create_performance_plot(table_data, headers, baseline_name, compare_name, output_file, plot_x_param): |
| 617 | + |
| 618 | + data_headers = headers[:-4] #Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup) |
| 619 | + plot_x_index = None |
| 620 | + plot_x_label = plot_x_param |
| 621 | + |
| 622 | + if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]: |
| 623 | + pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param) |
| 624 | + if pretty_name in data_headers: |
| 625 | + plot_x_index = data_headers.index(pretty_name) |
| 626 | + plot_x_label = pretty_name |
| 627 | + elif plot_x_param in data_headers: |
| 628 | + plot_x_index = data_headers.index(plot_x_param) |
| 629 | + plot_x_label = plot_x_param |
| 630 | + else: |
| 631 | + logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}") |
| 632 | + logger.error(f"To plot by '{plot_x_param}', include it in --show parameter or ensure it varies in your data.") |
| 633 | + return |
| 634 | + |
| 635 | + grouped_data = {} |
| 636 | + |
| 637 | + for i, row in enumerate(table_data): |
| 638 | + group_key_parts = [] |
| 639 | + test_name = row[-4] |
| 640 | + |
| 641 | + if plot_x_param in ["n_prompt", "n_gen", "n_depth"]: |
| 642 | + for j, val in enumerate(row[:-4]): |
| 643 | + header_name = data_headers[j] |
| 644 | + if val is not None and str(val).strip(): |
| 645 | + group_key_parts.append(f"{header_name}={val}") |
| 646 | + |
| 647 | + if plot_x_param == "n_prompt": |
| 648 | + assert "pp" in test_name, f"n_prompt test name {test_name} does not contain 'pp'" |
| 649 | + base_test = test_name.split("@")[0] |
| 650 | + x_value = base_test |
| 651 | + elif plot_x_param == "n_gen" and "tg" in test_name: |
| 652 | + assert "tg" in test_name, f"n_gen test name {test_name} does not contain 'tg'" |
| 653 | + x_value = test_name.split("@")[0] |
| 654 | + elif plot_x_param == "n_depth" and "@d" in test_name: |
| 655 | + assert "@d" in test_name, f"n_depth test name {test_name} does not contain '@d'" |
| 656 | + base_test = test_name.split("@d")[0] |
| 657 | + x_value = int(test_name.split("@d")[1]) |
| 658 | + else: |
| 659 | + base_test = test_name |
| 660 | + |
| 661 | + if base_test.strip(): |
| 662 | + group_key_parts.append(f"Test={base_test}") |
| 663 | + else: |
| 664 | + for j, val in enumerate(row[:-4]): |
| 665 | + if j != plot_x_index: |
| 666 | + header_name = data_headers[j] |
| 667 | + if val is not None and str(val).strip(): |
| 668 | + group_key_parts.append(f"{header_name}={val}") |
| 669 | + else: |
| 670 | + x_value = val |
| 671 | + |
| 672 | + group_key_parts.append(f"Test={test_name}") |
| 673 | + |
| 674 | + group_key = tuple(sorted(group_key_parts)) |
| 675 | + |
| 676 | + if group_key not in grouped_data: |
| 677 | + grouped_data[group_key] = [] |
| 678 | + |
| 679 | + grouped_data[group_key].append({ |
| 680 | + 'x_value': x_value, |
| 681 | + 'baseline': float(row[-3]), |
| 682 | + 'compare': float(row[-2]), |
| 683 | + 'speedup': float(row[-1]) |
| 684 | + }) |
| 685 | + |
| 686 | + if not grouped_data: |
| 687 | + logger.error("No data available for plotting") |
| 688 | + return |
| 689 | + |
| 690 | + |
| 691 | + def make_axes(num_groups, max_cols=2, base_size=(8, 4)): |
| 692 | + from math import ceil |
| 693 | + cols = 1 if num_groups == 1 else min(max_cols, num_groups) |
| 694 | + rows = ceil(num_groups / cols) |
| 695 | + |
| 696 | + # scale figure size by grid dimensions |
| 697 | + w, h = base_size |
| 698 | + fig, ax_arr = plt.subplots(rows, cols, |
| 699 | + figsize=(w * cols, h * rows), |
| 700 | + squeeze=False) |
| 701 | + |
| 702 | + axes = ax_arr.flatten()[:num_groups] |
| 703 | + return fig, axes |
| 704 | + |
| 705 | + num_groups = len(grouped_data) |
| 706 | + fig, axes = make_axes(num_groups) |
| 707 | + |
| 708 | + plot_idx = 0 |
| 709 | + |
| 710 | + for group_key, points in grouped_data.items(): |
| 711 | + if plot_idx >= len(axes): |
| 712 | + break |
| 713 | + ax = axes[plot_idx] |
| 714 | + |
| 715 | + try: |
| 716 | + points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0) |
| 717 | + x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted] |
| 718 | + except ValueError: |
| 719 | + points_sorted = sorted(points, key=lambda p: group_key) |
| 720 | + x_values = [p['x_value'] for p in points_sorted] |
| 721 | + |
| 722 | + baseline_vals = [p['baseline'] for p in points_sorted] |
| 723 | + compare_vals = [p['compare'] for p in points_sorted] |
| 724 | + |
| 725 | + ax.plot(x_values, baseline_vals, 'o-', color='skyblue', |
| 726 | + label=f'{baseline_name}', linewidth=2, markersize=6) |
| 727 | + ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8, |
| 728 | + label=f'{compare_name}', linewidth=2, markersize=6) |
| 729 | + |
| 730 | + if plot_x_param == "n_depth" and max(x_values) > 0 and max(x_values) > min(x_values) * 4: |
| 731 | + ax.set_xscale('log', base=2) |
| 732 | + unique_x = sorted(set(x_values)) |
| 733 | + ax.set_xticks(unique_x) |
| 734 | + ax.set_xticklabels([str(int(x)) for x in unique_x]) |
| 735 | + |
| 736 | + title_parts = [] |
| 737 | + for part in group_key: |
| 738 | + if '=' in part: |
| 739 | + key, value = part.split('=', 1) |
| 740 | + title_parts.append(f"{key}: {value}") |
| 741 | + |
| 742 | + title = ', '.join(title_parts) if title_parts else "Performance Comparison" |
| 743 | + |
| 744 | + ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold') |
| 745 | + ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold') |
| 746 | + ax.set_title(title, fontsize=12, fontweight='bold') |
| 747 | + ax.legend(loc='best', fontsize=10) |
| 748 | + ax.grid(True, alpha=0.3) |
| 749 | + |
| 750 | + plot_idx += 1 |
| 751 | + |
| 752 | + for i in range(plot_idx, len(axes)): |
| 753 | + axes[i].set_visible(False) |
| 754 | + |
| 755 | + fig.suptitle(f'Performance Comparison: {compare_name} vs {baseline_name}', |
| 756 | + fontsize=14, fontweight='bold') |
| 757 | + fig.subplots_adjust(top=1) |
| 758 | + |
| 759 | + |
| 760 | + plt.tight_layout() |
| 761 | + plt.savefig(output_file, dpi=300, bbox_inches='tight') |
| 762 | + plt.close() |
| 763 | + |
| 764 | + create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x) |
| 765 | + |
603 | 766 | print(tabulate( # noqa: NP100
|
604 | 767 | table,
|
605 | 768 | headers=headers,
|
|
0 commit comments