|
4 | 4 | import evaluate
|
5 | 5 | import numpy as np
|
6 | 6 | import json
|
| 7 | +from multiprocessing import Pool, cpu_count |
7 | 8 |
|
8 | 9 |
|
9 | 10 | def get_args():
|
@@ -52,12 +53,21 @@ def postprocess_text(preds, targets):
|
52 | 53 | return preds, targets
|
53 | 54 |
|
54 | 55 |
|
def compute_rouge_chunk(chunk):
    """Return per-sample ROUGE scores for one (predictions, references) pair.

    The metric is loaded inside the function (rather than shared from the
    parent) so that each multiprocessing worker builds its own instance and
    nothing unpicklable crosses the process boundary.

    Args:
        chunk: A 2-tuple ``(predictions, references)`` of equal-length
            sequences of strings.

    Returns:
        The dict produced by ``evaluate``'s ROUGE metric with
        ``use_aggregator=False``, i.e. one score list per ROUGE variant.
    """
    predictions, references = chunk
    rouge = evaluate.load("rouge")
    return rouge.compute(
        predictions=predictions,
        references=references,
        use_stemmer=True,
        use_aggregator=False,
    )
| 64 | + |
| 65 | + |
55 | 66 | def main():
|
56 | 67 |
|
57 | 68 | args = get_args()
|
58 | 69 | dataset_path = args.dataset_file
|
59 | 70 | checkpoint_path = args.checkpoint_path
|
60 |
| - metric = evaluate.load("rouge") |
61 | 71 | nltk.download("punkt")
|
62 | 72 | nltk.download("punkt_tab")
|
63 | 73 |
|
@@ -103,23 +113,43 @@ def main():
|
103 | 113 |
|
104 | 114 | preds, targets = postprocess_text(preds_decoded_text, target_required)
|
105 | 115 |
|
106 |
| - result = metric.compute( |
107 |
| - predictions=preds, references=targets, use_stemmer=True, use_aggregator=False |
108 |
| - ) |
109 |
| - result = {k: round(np.mean(v) * 100, 4) for k, v in result.items()} |
| 116 | + # Split data into chunks for parallel processing |
| 117 | + num_chunks = cpu_count() # Number of parallel processes |
| 118 | + chunk_size = len(preds) // num_chunks + (len(preds) % num_chunks > 0) |
| 119 | + |
| 120 | + chunks = [ |
| 121 | + (preds[i:i + chunk_size], targets[i:i + chunk_size]) |
| 122 | + for i in range(0, len(preds), chunk_size) |
| 123 | + ] |
| 124 | + |
| 125 | + # Use multiprocessing Pool to compute ROUGE scores in parallel |
| 126 | + with Pool(num_chunks) as pool: |
| 127 | + results_list = pool.map(compute_rouge_chunk, chunks) |
| 128 | + |
| 129 | + # Aggregate results from all chunks |
| 130 | + aggregated_results = {} |
| 131 | + |
| 132 | + for result in results_list: |
| 133 | + for k, v in result.items(): |
| 134 | + if k not in aggregated_results: |
| 135 | + aggregated_results[k] = [] |
| 136 | + aggregated_results[k].extend(v) |
| 137 | + |
| 138 | + final_result = {k: round(np.mean(v) * 100, 4) |
| 139 | + for k, v in aggregated_results.items()} |
| 140 | + |
110 | 141 | prediction_lens = [len(pred) for pred in preds]
|
111 | 142 | gen_num = len(preds)
|
112 | 143 |
|
113 |
| - result = { |
114 |
| - **result, |
| 144 | + final_result.update({ |
115 | 145 | "gen_len": np.sum(prediction_lens),
|
116 | 146 | "gen_num": gen_num,
|
117 | 147 | "gen_tok_len": gen_tok_len,
|
118 | 148 | "tokens_per_sample": round(gen_tok_len / gen_num, 1),
|
119 |
| - } |
| 149 | + }) |
120 | 150 |
|
121 | 151 | print("\nResults\n")
|
122 |
| - print(result) |
| 152 | + print(final_result) |
123 | 153 |
|
124 | 154 |
|
125 | 155 | if __name__ == "__main__":
|
|
0 commit comments