Skip to content

Commit 9249d47

Browse files
committed
own function to get Diverging of PRs
1 parent 328ecc3 commit 9249d47

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

routers/repo/pull.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ func PrepareViewPullInfo(ctx *context.Context, issue *models.Issue) *git.Compare
342342

343343
setMergeTarget(ctx, pull)
344344

345-
divergence, divergenceError := repofiles.CountDivergingCommits(repo, pull.HeadBranch)
345+
divergence, divergenceError := pull_service.GetDiverging(pull)
346346
if divergenceError != nil {
347347
ctx.ServerError("CountDivergingCommits", divergenceError)
348348
return nil

services/pull/update.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ package pull
66

77
import (
88
"fmt"
9+
"strconv"
10+
"strings"
911

1012
"code.gitea.io/gitea/models"
13+
"code.gitea.io/gitea/modules/git"
1114
"code.gitea.io/gitea/modules/log"
1215
)
1316

@@ -29,6 +32,17 @@ func Update(pull *models.PullRequest, doer *models.User, message string) (err er
2932
return fmt.Errorf("LoadBaseRepo: %v", err)
3033
}
3134

35+
diffCount, err := GetDiverging(pull)
36+
if err != nil {
37+
return err
38+
} else if diffCount.Behind == 0 {
39+
return fmt.Errorf("HeadBranch of PR %d is up to date", pull.Index)
40+
}
41+
42+
defer func() {
43+
go AddTestPullRequestTask(doer, pr.HeadRepo.ID, pr.HeadBranch, false, "", "")
44+
}()
45+
3246
if err := rawMerge(pr, doer, models.MergeStyleMerge, message); err != nil {
3347
return err
3448
}
@@ -49,3 +63,70 @@ func IsUserAllowedToUpdate(pull *models.PullRequest, p models.Permission, user *
4963
}
5064
return IsUserAllowedToMerge(pr, p, user)
5165
}
66+
67+
// GetDiverging determines how many commits a PR is ahead or behind the PR base branch
68+
func GetDiverging(pr *models.PullRequest) (*git.DivergeObject, error) {
69+
log.Trace("PushToBaseRepo[%d]: pushing commits to base repo '%s'", pr.BaseRepoID, pr.GetGitRefName())
70+
71+
if pr.BaseRepo == nil {
72+
if err := pr.LoadBaseRepo(); err != nil {
73+
return nil, err
74+
}
75+
}
76+
if pr.HeadRepo == nil {
77+
if err := pr.LoadHeadRepo(); err != nil {
78+
return nil, err
79+
}
80+
}
81+
82+
headRepoPath := pr.HeadRepo.RepoPath()
83+
headGitRepo, err := git.OpenRepository(headRepoPath)
84+
if err != nil {
85+
return nil, fmt.Errorf("OpenRepository: %v", err)
86+
}
87+
defer headGitRepo.Close()
88+
89+
if pr.BaseRepoID == pr.HeadRepoID {
90+
diff, err := git.GetDivergingCommits(pr.HeadRepo.RepoPath(), pr.BaseBranch, pr.HeadBranch)
91+
return &diff, err
92+
}
93+
94+
tmpRemoteName := fmt.Sprintf("tmp-pull-%d-base", pr.ID)
95+
if err = headGitRepo.AddRemote(tmpRemoteName, pr.BaseRepo.RepoPath(), true); err != nil {
96+
return nil, fmt.Errorf("headGitRepo.AddRemote: %v", err)
97+
}
98+
// Make sure to remove the remote even if the push fails
99+
defer func() {
100+
if err := headGitRepo.RemoveRemote(tmpRemoteName); err != nil {
101+
log.Error("CountDiverging: RemoveRemote: %s", err)
102+
}
103+
}()
104+
105+
// $(git rev-list --count tmp-pull-1-base/master..feature) commits ahead of master
106+
ahead, errorAhead := checkDivergence(headRepoPath, fmt.Sprintf("%s/%s", tmpRemoteName, pr.BaseBranch), pr.HeadBranch)
107+
if errorAhead != nil {
108+
return &git.DivergeObject{}, errorAhead
109+
}
110+
111+
// $(git rev-list --count feature..tmp-pull-1-base/master) commits behind master
112+
behind, errorBehind := checkDivergence(headRepoPath, pr.HeadBranch, fmt.Sprintf("%s/%s", tmpRemoteName, pr.BaseBranch))
113+
if errorBehind != nil {
114+
return &git.DivergeObject{}, errorBehind
115+
}
116+
117+
return &git.DivergeObject{ahead, behind}, nil
118+
}
119+
120+
func checkDivergence(repoPath string, baseBranch string, targetBranch string) (int, error) {
121+
branches := fmt.Sprintf("%s..%s", baseBranch, targetBranch)
122+
cmd := git.NewCommand("rev-list", "--count", branches)
123+
stdout, err := cmd.RunInDir(repoPath)
124+
if err != nil {
125+
return -1, err
126+
}
127+
outInteger, errInteger := strconv.Atoi(strings.Trim(stdout, "\n"))
128+
if errInteger != nil {
129+
return -1, errInteger
130+
}
131+
return outInteger, nil
132+
}

0 commit comments

Comments
 (0)