|
1 | 1 | from typing import TYPE_CHECKING
|
2 | 2 |
|
3 |
| -import requests |
4 | 3 | from github import Repository
|
5 | 4 | from github.PullRequest import PullRequest
|
6 | 5 | from unidiff import PatchSet
|
@@ -39,28 +38,6 @@ def get_file_to_changed_ranges(pull_patch_set: PatchSet) -> dict[str, list]:
|
39 | 38 | return file_to_changed_ranges
|
40 | 39 |
|
41 | 40 |
|
42 |
| -def get_pull_patch_set(op: RepoOperator, pull: PullRequestContext) -> PatchSet: |
43 |
| - # Get the diff directly from GitHub's API |
44 |
| - if not op.remote_git_repo: |
45 |
| - msg = "GitHub API client is required to get PR diffs" |
46 |
| - raise ValueError(msg) |
47 |
| - |
48 |
| - # Get the diff directly from the PR |
49 |
| - diff_url = pull.raw_data.get("diff_url") |
50 |
| - if diff_url: |
51 |
| - # Fetch the diff content from the URL |
52 |
| - response = requests.get(diff_url) |
53 |
| - response.raise_for_status() |
54 |
| - diff = response.text |
55 |
| - else: |
56 |
| - # If diff_url not available, get the patch directly |
57 |
| - diff = pull.get_patch() |
58 |
| - |
59 |
| - # Parse the diff into a PatchSet |
60 |
| - pull_patch_set = PatchSet(diff) |
61 |
| - return pull_patch_set |
62 |
| - |
63 |
| - |
64 | 41 | def to_1_indexed(zero_indexed_range: range) -> range:
|
65 | 42 | """Converts a n-indexed range to n+1-indexed.
|
66 | 43 | Primarily to convert 0-indexed ranges to 1 indexed
|
@@ -131,7 +108,7 @@ def __init__(self, op: RepoOperator, codebase: "Codebase", pr: PullRequest):
|
131 | 108 | def modified_file_ranges(self) -> dict[str, list[tuple[int, int]]]:
|
132 | 109 | """Files and the ranges within that are modified"""
|
133 | 110 | if not self._modified_file_ranges:
|
134 |
| - pull_patch_set = get_pull_patch_set(op=self._op, pull=self._gh_pr) |
| 111 | + pull_patch_set = self.get_pull_patch_set() |
135 | 112 | self._modified_file_ranges = get_file_to_changed_ranges(pull_patch_set)
|
136 | 113 | return self._modified_file_ranges
|
137 | 114 |
|
@@ -174,15 +151,16 @@ def get_pr_diff(self) -> str:
|
174 | 151 | raise ValueError(msg)
|
175 | 152 |
|
176 | 153 | # Get the diff directly from the PR
|
177 |
| - diff_url = self._gh_pr.raw_data.get("diff_url") |
178 |
| - if diff_url: |
179 |
| - # Fetch the diff content from the URL |
180 |
| - response = requests.get(diff_url) |
181 |
| - response.raise_for_status() |
182 |
| - return response.text |
183 |
| - else: |
184 |
| - # If diff_url not available, get the patch directly |
185 |
| - return self._gh_pr.get_patch() |
| 154 | + status, _, res = self._op.remote_git_repo.repo._requester.requestJson("GET", self._gh_pr.url, headers={"Accept": "application/vnd.github.v3.diff"}) |
| 155 | + if status != 200: |
| 156 | + msg = f"Failed to get PR diff: {res}" |
| 157 | + raise Exception(msg) |
| 158 | + return res |
| 159 | + |
| 160 | + def get_pull_patch_set(self) -> PatchSet: |
| 161 | + diff = self.get_pr_diff() |
| 162 | + pull_patch_set = PatchSet(diff) |
| 163 | + return pull_patch_set |
186 | 164 |
|
187 | 165 | def get_commit_sha(self) -> str:
|
188 | 166 | """Get the commit SHA of the PR"""
|
|
0 commit comments