Skip to content
This repository was archived by the owner on Apr 12, 2019. It is now read-only.

Commit d7487da

Browse files
authored
Add method to allow custom sorting of tree entry list (#81)
* Add method to allow custom sorting of tree entry list * Add tests for git tree entries sorting
1 parent e8ae926 commit d7487da

File tree

2 files changed

+71
-12
lines changed

2 files changed

+71
-12
lines changed

tree_entry.go

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,35 +116,51 @@ func (te *TreeEntry) GetSubJumpablePathName() string {
116116
// Entries a list of entry
117117
type Entries []*TreeEntry
118118

119-
var sorter = []func(t1, t2 *TreeEntry) bool{
120-
func(t1, t2 *TreeEntry) bool {
119+
type customSortableEntries struct {
120+
Comparer func(s1, s2 string) bool
121+
Entries
122+
}
123+
124+
var sorter = []func(t1, t2 *TreeEntry, cmp func(s1, s2 string) bool) bool{
125+
func(t1, t2 *TreeEntry, cmp func(s1, s2 string) bool) bool {
121126
return (t1.IsDir() || t1.IsSubModule()) && !t2.IsDir() && !t2.IsSubModule()
122127
},
123-
func(t1, t2 *TreeEntry) bool {
124-
return t1.name < t2.name
128+
func(t1, t2 *TreeEntry, cmp func(s1, s2 string) bool) bool {
129+
return cmp(t1.name, t2.name)
125130
},
126131
}
127132

128-
func (tes Entries) Len() int { return len(tes) }
129-
func (tes Entries) Swap(i, j int) { tes[i], tes[j] = tes[j], tes[i] }
130-
func (tes Entries) Less(i, j int) bool {
131-
t1, t2 := tes[i], tes[j]
133+
func (ctes customSortableEntries) Len() int { return len(ctes.Entries) }
134+
135+
func (ctes customSortableEntries) Swap(i, j int) {
136+
ctes.Entries[i], ctes.Entries[j] = ctes.Entries[j], ctes.Entries[i]
137+
}
138+
139+
func (ctes customSortableEntries) Less(i, j int) bool {
140+
t1, t2 := ctes.Entries[i], ctes.Entries[j]
132141
var k int
133142
for k = 0; k < len(sorter)-1; k++ {
134143
s := sorter[k]
135144
switch {
136-
case s(t1, t2):
145+
case s(t1, t2, ctes.Comparer):
137146
return true
138-
case s(t2, t1):
147+
case s(t2, t1, ctes.Comparer):
139148
return false
140149
}
141150
}
142-
return sorter[k](t1, t2)
151+
return sorter[k](t1, t2, ctes.Comparer)
143152
}
144153

145154
// Sort sort the list of entry
146155
func (tes Entries) Sort() {
147-
sort.Sort(tes)
156+
sort.Sort(customSortableEntries{func(s1, s2 string) bool {
157+
return s1 < s2
158+
}, tes})
159+
}
160+
161+
// CustomSort customizable string comparing sort entry list
162+
func (tes Entries) CustomSort(cmp func(s1, s2 string) bool) {
163+
sort.Sort(customSortableEntries{cmp, tes})
148164
}
149165

150166
type commitInfo struct {

tree_entry_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"path/filepath"
1010
"testing"
1111
"time"
12+
13+
"github.com/stretchr/testify/assert"
1214
)
1315

1416
const benchmarkReposDir = "benchmark/repos/"
@@ -61,3 +63,44 @@ func BenchmarkEntries_GetCommitsInfo(b *testing.B) {
6163
})
6264
}
6365
}
66+
67+
func getTestEntries() Entries {
68+
return Entries{
69+
&TreeEntry{name: "v1.0", mode: EntryModeTree},
70+
&TreeEntry{name: "v2.0", mode: EntryModeTree},
71+
&TreeEntry{name: "v2.1", mode: EntryModeTree},
72+
&TreeEntry{name: "v2.12", mode: EntryModeTree},
73+
&TreeEntry{name: "v2.2", mode: EntryModeTree},
74+
&TreeEntry{name: "v12.0", mode: EntryModeTree},
75+
&TreeEntry{name: "abc", mode: EntryModeBlob},
76+
&TreeEntry{name: "bcd", mode: EntryModeBlob},
77+
}
78+
}
79+
80+
func TestEntriesSort(t *testing.T) {
81+
entries := getTestEntries()
82+
entries.Sort()
83+
assert.Equal(t, "v1.0", entries[0].Name())
84+
assert.Equal(t, "v12.0", entries[1].Name())
85+
assert.Equal(t, "v2.0", entries[2].Name())
86+
assert.Equal(t, "v2.1", entries[3].Name())
87+
assert.Equal(t, "v2.12", entries[4].Name())
88+
assert.Equal(t, "v2.2", entries[5].Name())
89+
assert.Equal(t, "abc", entries[6].Name())
90+
assert.Equal(t, "bcd", entries[7].Name())
91+
}
92+
93+
func TestEntriesCustomSort(t *testing.T) {
94+
entries := getTestEntries()
95+
entries.CustomSort(func(s1, s2 string) bool {
96+
return s1 > s2
97+
})
98+
assert.Equal(t, "v2.2", entries[0].Name())
99+
assert.Equal(t, "v2.12", entries[1].Name())
100+
assert.Equal(t, "v2.1", entries[2].Name())
101+
assert.Equal(t, "v2.0", entries[3].Name())
102+
assert.Equal(t, "v12.0", entries[4].Name())
103+
assert.Equal(t, "v1.0", entries[5].Name())
104+
assert.Equal(t, "bcd", entries[6].Name())
105+
assert.Equal(t, "abc", entries[7].Name())
106+
}

0 commit comments

Comments
 (0)