import asyncio
import json
import traceback
from pathlib import Path
import modal
import click
from datetime import datetime
from codegen.extensions.swebench.utils import SWEBenchDataset, get_swe_bench_example, get_swe_bench_examples
from codegen.extensions.swebench.report import generate_report

PREDS_DNAME = Path(__file__).parent / "predictions"
LOG_DIR = Path(__file__).parent / "logs"

run_agent_modal = modal.Function.lookup("swebench-agent-run", "run_agent_modal")
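# NOTE: Function.lookup resolves a function from an already-deployed Modal app, so this
# assumes the "swebench-agent-run" app exposing `run_agent_modal` has been deployed beforehand.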


async def process_batch(examples, batch_size=10):
    """Process a batch of examples concurrently.

    Args:
        examples: List of SweBenchExample objects to process
        batch_size: Number of examples to process concurrently (defaults to 10).
            Keeping this modest provides good parallelization
            while staying well within Modal's limits.
    """
    results = []

    # Process examples in batches
    for i in range(0, len(examples), batch_size):
        batch = examples[i : i + batch_size]

        # Create tasks for this batch
        batch_tasks = [run_agent_modal.remote.aio(example) for example in batch]

        # Report progress (ceiling division so the displayed total batch count is exact)
        num_batches = (len(examples) + batch_size - 1) // batch_size
        print(f"Processing batch {i // batch_size + 1}/{num_batches} (examples {i + 1}-{min(i + batch_size, len(examples))})")

        try:
            # Wait for all tasks in this batch to complete
            batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
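            # With return_exceptions=True, failures come back as exception objects instead of
            # raising here, so one bad example does not cancel the rest of the batch.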

            # Store results
            for example, result in zip(batch, batch_results):
                error_info = None

                if isinstance(result, Exception):
                    error_type = type(result).__name__
                    error_info = {
                        "error_type": error_type,
                        "error_message": str(result),
                        "traceback": traceback.format_exception(type(result), result, result.__traceback__),
                    }

                    if isinstance(result, modal.exception.Error):
                        error_info["modal_error_code"] = getattr(result, "code", None)
                        error_info["modal_error_details"] = getattr(result, "details", None)

                    print(f"Error processing {example.instance_id}:")
                    print(f"Type: {error_type}")
                    print(f"Message: {str(result)}")
                    print("Traceback:")
                    print("".join(error_info["traceback"]))

                    results.append({"instance_id": example.instance_id, "status": "error", "error_info": error_info})
                else:
                    if result is None:
                        print(f"Warning: Null result for {example.instance_id}")
                        results.append({"instance_id": example.instance_id, "status": "error", "error_info": {"error_type": "NullResult", "error_message": "Process returned None"}})
                    else:
                        results.append(result)

        except Exception as e:
            print("Batch processing error:")
            print(f"Type: {type(e).__name__}")
            print(f"Message: {str(e)}")
            traceback.print_exc()

            # Mark all examples in the batch as failed
            for example in batch:
                results.append(
                    {
                        "instance_id": example.instance_id,
                        "status": "error",
                        "error_info": {"error_type": type(e).__name__, "error_message": str(e), "traceback": traceback.format_exc(), "batch_failure": True},
                    }
                )

    return results


async def run_eval(use_existing_preds, dataset, length, instance_id=None):
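    """Fetch SWE-bench examples, run them through the Modal agent (unless reusing existing predictions), save per-example results and a run summary, then generate a report."""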
    dataset = SWEBenchDataset(dataset)
    if instance_id:
        examples = [get_swe_bench_example(instance_id, dataset=dataset)]
    else:
        examples = get_swe_bench_examples(dataset=dataset, length=length)

    try:
        if not use_existing_preds:
            print(f"Processing {len(examples)} examples...")

            # Create output directory if it doesn't exist
            PREDS_DNAME.mkdir(exist_ok=True)
            results_dir = PREDS_DNAME / "results"
            results_dir.mkdir(exist_ok=True)

            # Create a timestamp for this run
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

            # Process all examples in parallel batches
            results = await process_batch(examples)

            # Save individual results
            for result in results:
                if result and "instance_id" in result:
                    instance_id = result["instance_id"]
                    output_file = results_dir / f"{instance_id}.json"
                    with open(output_file, "w") as f:
                        json.dump(result, f, indent=4)

            # Save summary file
            summary_file = results_dir / f"summary_{timestamp}.json"
            summary = {
                "timestamp": timestamp,
                "total_examples": len(examples),
                "successful": len([r for r in results if r and "status" not in r]),
                "failed": len([r for r in results if r and "status" in r and r["status"] == "error"]),
                "error_types": {},
                "results": results,
            }

            # Collect error statistics
            for result in results:
                if result and "status" in result and result["status"] == "error":
                    error_type = result.get("error_info", {}).get("error_type", "Unknown")
                    summary["error_types"][error_type] = summary["error_types"].get(error_type, 0) + 1

            with open(summary_file, "w") as f:
                json.dump(summary, f, indent=4)

            print("\nProcessing complete!")
            print(f"Results saved to: {results_dir}")
            print(f"Summary saved to: {summary_file}")
            print(f"Successful: {summary['successful']}/{summary['total_examples']}")
            print(f"Failed: {summary['failed']}/{summary['total_examples']}")
            if summary["error_types"]:
                print("\nError type distribution:")
                for error_type, count in summary["error_types"].items():
                    print(f"  {error_type}: {count}")

        # Generate Report on Modal
        generate_report(PREDS_DNAME, LOG_DIR, dataset)
    except Exception:
        print("Fatal error in run_eval:")
        traceback.print_exc()
        raise


@click.command()
@click.option("--use-existing-preds", is_flag=True, help="Use existing predictions instead of generating new ones.")
@click.option("--dataset", help="The dataset to use.", type=click.Choice([dataset.value for dataset in SWEBenchDataset]), default=SWEBenchDataset.LITE.value)
@click.option("--length", help="The number of examples to process.", type=int, default=10)
@click.option("--instance-id", help="The instance ID of the example to process.")
def run_eval_command(use_existing_preds, dataset, length, instance_id):
    asyncio.run(run_eval(use_existing_preds, dataset, length, instance_id))


if __name__ == "__main__":
    run_eval_command()
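
# Example invocation (the script name below is illustrative; use this file's actual path):
#   python run_eval.py --length 20
#   python run_eval.py --use-existing-preds --dataset <SWEBenchDataset value>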