@@ -26,9 +26,9 @@ def parse_args():
26
26
required = True ,
27
27
)
28
28
parser .add_argument (
29
- "--pr " ,
30
- type = int ,
31
- help = "Number of the PR in the stack to check and create corresponding PR" ,
29
+ "--ref " ,
30
+ type = str ,
31
+ help = "Ref fo PR in the stack to check and create corresponding PR" ,
32
32
required = True ,
33
33
)
34
34
return parser .parse_args ()
@@ -68,12 +68,18 @@ def extract_stack_from_body(pr_body: str) -> List[int]:
68
68
return list (reversed (prs ))
69
69
70
70
71
- def get_pr_stack_from_number (pr_number : int , repo : Repository ) -> List [int ]:
71
+ def get_pr_stack_from_number (ref : str , repo : Repository ) -> List [int ]:
72
+ if ref .isnumeric ():
73
+ pr_number = int (ref )
74
+ else :
75
+ branch_name = ref .replace ("refs/heads/" , "" )
76
+ pr_number = repo .get_branch (branch_name ).commit .get_pulls ()[0 ].number
77
+
72
78
pr_stack = extract_stack_from_body (repo .get_pull (pr_number ).body )
73
79
74
80
if not pr_stack :
75
81
raise Exception (
76
- f"Could not find PR stack in body of # { pr_number } . "
82
+ f"Could not find PR stack in body of ref . "
77
83
+ "Please make sure that the PR was created with ghstack."
78
84
)
79
85
@@ -129,7 +135,7 @@ def main():
129
135
130
136
with Github (auth = Auth .Token (os .environ ["GITHUB_TOKEN" ])) as gh :
131
137
repo = gh .get_repo (args .repo )
132
- create_prs_for_orig_branch (get_pr_stack_from_number (args .pr , repo ), repo )
138
+ create_prs_for_orig_branch (get_pr_stack_from_number (args .ref , repo ), repo )
133
139
134
140
135
141
if __name__ == "__main__" :
0 commit comments