@@ -106,27 +106,61 @@ def validate_commits(self, data):
106
106
return False , f"More than 1 commit! { len (data ['commits' ])} "
107
107
return True
108
108
109
- def validate_pr (self ):
109
+ def _normalize_pr (self , parg : str ):
110
+ if parg .isdigit ():
111
+ return parg
112
+ elif parg .startswith ("https://github.com/llvm/llvm-project/pull" ):
113
+ # try to parse the following url https://github.com/llvm/llvm-project/pull/114089
114
+ i = parg [parg .rfind ("/" ) + 1 :]
115
+ if not i .isdigit ():
116
+ raise RuntimeError (f"{ i } is not a number, malformatted input." )
117
+ return i
118
+ else :
119
+ raise RuntimeError (
120
+ f"PR argument must be PR ID or pull request URL - { parg } is wrong."
121
+ )
122
+
123
+ def load_pr_data (self ):
124
+ self .args .pr = self ._normalize_pr (self .args .pr )
110
125
fields_to_fetch = [
111
126
"baseRefName" ,
127
+ "commits" ,
128
+ "headRefName" ,
129
+ "headRepository" ,
130
+ "headRepositoryOwner" ,
112
131
"reviewDecision" ,
113
- "title " ,
132
+ "state " ,
114
133
"statusCheckRollup" ,
134
+ "title" ,
115
135
"url" ,
116
- "state" ,
117
- "commits" ,
118
136
]
137
+ print (f"> Loading PR { self .args .pr } ..." )
119
138
o = self .run_gh (
120
139
"pr" ,
121
140
["view" , self .args .pr , "--json" , "," .join (fields_to_fetch )],
122
141
)
123
- prdata = json .loads (o )
142
+ self . prdata = json .loads (o )
124
143
125
144
# save the baseRefName (target branch) so that we know where to push
126
- self .target_branch = prdata ["baseRefName" ]
145
+ self .target_branch = self .prdata ["baseRefName" ]
146
+ srepo = self .prdata ["headRepository" ]["name" ]
147
+ sowner = self .prdata ["headRepositoryOwner" ]["login" ]
148
+ self .source_url = f"https://github.com/{ sowner } /{ srepo } "
149
+ self .source_branch = self .prdata ["headRefName" ]
150
+
151
+ if srepo != "llvm-project" :
152
+ print ("The target repo is NOT llvm-project, check the PR!" )
153
+ sys .exit (1 )
154
+
155
+ if sowner == "llvm" :
156
+ print (
157
+ "The source owner should never be github.com/llvm, double check the PR!"
158
+ )
159
+ sys .exit (1 )
127
160
128
- print (f"> Handling PR { self .args .pr } - { prdata ['title' ]} " )
129
- print (f"> { prdata ['url' ]} " )
161
+ def validate_pr (self ):
162
+ print (f"> Handling PR { self .args .pr } - { self .prdata ['title' ]} " )
163
+ print (f"> { self .prdata ['url' ]} " )
130
164
131
165
VALIDATIONS = {
132
166
"state" : self .validate_state ,
@@ -141,7 +175,7 @@ def validate_pr(self):
141
175
total_ok = True
142
176
for val_name , val_func in VALIDATIONS .items ():
143
177
try :
144
- validation_data = val_func (prdata )
178
+ validation_data = val_func (self . prdata )
145
179
except :
146
180
validation_data = False
147
181
ok = None
@@ -166,24 +200,42 @@ def validate_pr(self):
166
200
return total_ok
167
201
168
202
def rebase_pr (self ):
169
- print ("> Rebasing" )
170
- self .run_gh ("pr" , ["update-branch" , "--rebase" , self .args .pr ])
171
- print ("> Waiting for GitHub to update PR" )
172
- time .sleep (4 )
203
+ print ("> Fetching upstream" )
204
+ subprocess .run (["git" , "fetch" , "--all" ], check = True )
205
+ print ("> Rebasing..." )
206
+ subprocess .run (
207
+ ["git" , "rebase" , self .args .upstream + "/" + self .target_branch ], check = True
208
+ )
209
+ print ("> Publish rebase..." )
210
+ subprocess .run (
211
+ ["git" , "push" , "--force" , self .source_url , f"HEAD:{ self .source_branch } " ]
212
+ )
173
213
174
214
def checkout_pr (self ):
175
215
print ("> Fetching PR changes..." )
216
+ self .merge_branch = "llvm_merger_" + self .args .pr
176
217
self .run_gh (
177
218
"pr" ,
178
219
[
179
220
"checkout" ,
180
221
self .args .pr ,
181
222
"--force" ,
182
223
"--branch" ,
183
- "llvm_merger_" + self .args . pr ,
224
+ self .merge_branch ,
184
225
],
185
226
)
186
227
228
+ # get the branch information so that we can use it for
229
+ # pushing later.
230
+ p = subprocess .run (
231
+ ["git" , "config" , f"branch.{ self .merge_branch } .merge" ],
232
+ check = True ,
233
+ capture_output = True ,
234
+ text = True ,
235
+ )
236
+ upstream_branch = p .stdout .strip ().replace ("refs/heads/" , "" )
237
+ print (upstream_branch )
238
+
187
239
def push_upstream (self ):
188
240
print ("> Pushing changes..." )
189
241
subprocess .run (
@@ -201,7 +253,7 @@ def delete_local_branch(self):
201
253
parser = argparse .ArgumentParser ()
202
254
parser .add_argument (
203
255
"pr" ,
204
- help = "The Pull Request ID that should be merged into a release." ,
256
+ help = "The Pull Request ID that should be merged into a release. Can be number or URL " ,
205
257
)
206
258
parser .add_argument (
207
259
"--skip-validation" ,
@@ -224,9 +276,20 @@ def delete_local_branch(self):
224
276
parser .add_argument (
225
277
"--validate-only" , action = "store_true" , help = "Only run the validations."
226
278
)
279
+ parser .add_argument (
280
+ "--rebase-only" , action = "store_true" , help = "Only rebase and exit"
281
+ )
227
282
args = parser .parse_args ()
228
283
229
284
merger = PRMerger (args )
285
+ merger .load_pr_data ()
286
+
287
+ if args .rebase_only :
288
+ merger .checkout_pr ()
289
+ merger .rebase_pr ()
290
+ merger .delete_local_branch ()
291
+ sys .exit (0 )
292
+
230
293
if not merger .validate_pr ():
231
294
print ()
232
295
print (
@@ -239,8 +302,8 @@ def delete_local_branch(self):
239
302
print ("! --validate-only passed, will exit here" )
240
303
sys .exit (0 )
241
304
242
- merger .rebase_pr ()
243
305
merger .checkout_pr ()
306
+ merger .rebase_pr ()
244
307
245
308
if args .no_push :
246
309
print ()
0 commit comments