Skip to content

Commit 0430b37

Browse files
committed
Add support for collation sequences implemented in Go.
This allows Go programs to register custom comparison functions with sqlite, and ORDER BY that comparator.
1 parent 83772a7 commit 0430b37

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

callback.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ func doneTrampoline(ctx *C.sqlite3_context) {
5353
ai.Done(ctx)
5454
}
5555

56+
//export compareTrampoline
57+
func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
58+
cmp := lookupHandle(handlePtr).(func(string, string) int)
59+
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
60+
}
61+
5662
// Use handles to avoid passing Go pointers to C.
5763

5864
type handleVal struct {

sqlite3.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ int _sqlite3_create_function(
100100
}
101101
102102
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
103+
104+
int compareTrampoline(void*, int, char*, int, char*);
103105
*/
104106
import "C"
105107
import (
@@ -313,6 +315,29 @@ func (tx *SQLiteTx) Rollback() error {
313315
return err
314316
}
315317

318+
// RegisterCollation makes a Go function available as a collation.
319+
//
320+
// cmp receives two UTF-8 strings, a and b. The result should be 0 if
321+
// a==b, -1 if a < b, and +1 if a > b.
322+
//
323+
// cmp must always return the same result given the same
324+
// inputs. Additionally, it must have the following properties for all
325+
// strings A, B and C: if A==B then B==A; if A==B and B==C then A==C;
326+
// if A<B then B>A; if A<B and B<C then A<C.
327+
//
328+
// If cmp does not obey these constraints, sqlite3's behavior is
329+
// undefined when the collation is used.
330+
func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int) error {
331+
handle := newHandle(c, cmp)
332+
cname := C.CString(name)
333+
defer C.free(unsafe.Pointer(cname))
334+
rv := C.sqlite3_create_collation(c.db, cname, C.SQLITE_UTF8, unsafe.Pointer(handle), (*[0]byte)(unsafe.Pointer(C.compareTrampoline)))
335+
if rv != C.SQLITE_OK {
336+
return c.lastError()
337+
}
338+
return nil
339+
}
340+
316341
// RegisterFunc makes a Go function available as a SQLite function.
317342
//
318343
// The Go function can have arguments of the following types: any

sqlite3_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,127 @@ func TestFunctionRegistration(t *testing.T) {
12131213
}
12141214
}
12151215

1216+
func rot13(r rune) rune {
1217+
switch {
1218+
case r >= 'A' && r <= 'Z':
1219+
return 'A' + (r-'A'+13)%26
1220+
case r >= 'a' && r <= 'z':
1221+
return 'a' + (r-'a'+13)%26
1222+
}
1223+
return r
1224+
}
1225+
1226+
func TestCollationRegistration(t *testing.T) {
1227+
collateRot13 := func(a, b string) int {
1228+
ra, rb := strings.Map(rot13, a), strings.Map(rot13, b)
1229+
return strings.Compare(ra, rb)
1230+
}
1231+
collateRot13Reverse := func(a, b string) int {
1232+
return collateRot13(b, a)
1233+
}
1234+
1235+
sql.Register("sqlite3_CollationRegistration", &SQLiteDriver{
1236+
ConnectHook: func(conn *SQLiteConn) error {
1237+
if err := conn.RegisterCollation("rot13", collateRot13); err != nil {
1238+
return err
1239+
}
1240+
if err := conn.RegisterCollation("rot13reverse", collateRot13Reverse); err != nil {
1241+
return err
1242+
}
1243+
return nil
1244+
},
1245+
})
1246+
1247+
db, err := sql.Open("sqlite3_CollationRegistration", ":memory:")
1248+
if err != nil {
1249+
t.Fatal("Failed to open database:", err)
1250+
}
1251+
defer db.Close()
1252+
1253+
populate := []string{
1254+
`CREATE TABLE test (s TEXT)`,
1255+
`INSERT INTO test VALUES ("aaaa")`,
1256+
`INSERT INTO test VALUES ("ffff")`,
1257+
`INSERT INTO test VALUES ("qqqq")`,
1258+
`INSERT INTO test VALUES ("tttt")`,
1259+
`INSERT INTO test VALUES ("zzzz")`,
1260+
}
1261+
for _, stmt := range populate {
1262+
if _, err := db.Exec(stmt); err != nil {
1263+
t.Fatal("Failed to populate test DB:", err)
1264+
}
1265+
}
1266+
1267+
ops := []struct {
1268+
query string
1269+
want []string
1270+
}{
1271+
{
1272+
"SELECT * FROM test ORDER BY s COLLATE rot13 ASC",
1273+
[]string{
1274+
"qqqq",
1275+
"tttt",
1276+
"zzzz",
1277+
"aaaa",
1278+
"ffff",
1279+
},
1280+
},
1281+
{
1282+
"SELECT * FROM test ORDER BY s COLLATE rot13 DESC",
1283+
[]string{
1284+
"ffff",
1285+
"aaaa",
1286+
"zzzz",
1287+
"tttt",
1288+
"qqqq",
1289+
},
1290+
},
1291+
{
1292+
"SELECT * FROM test ORDER BY s COLLATE rot13reverse ASC",
1293+
[]string{
1294+
"ffff",
1295+
"aaaa",
1296+
"zzzz",
1297+
"tttt",
1298+
"qqqq",
1299+
},
1300+
},
1301+
{
1302+
"SELECT * FROM test ORDER BY s COLLATE rot13reverse DESC",
1303+
[]string{
1304+
"qqqq",
1305+
"tttt",
1306+
"zzzz",
1307+
"aaaa",
1308+
"ffff",
1309+
},
1310+
},
1311+
}
1312+
1313+
for _, op := range ops {
1314+
rows, err := db.Query(op.query)
1315+
if err != nil {
1316+
t.Fatalf("Query %q failed: %s", op.query, err)
1317+
}
1318+
got := []string{}
1319+
defer rows.Close()
1320+
for rows.Next() {
1321+
var s string
1322+
if err = rows.Scan(&s); err != nil {
1323+
t.Fatalf("Reading row for %q: %s", op.query, err)
1324+
}
1325+
got = append(got, s)
1326+
}
1327+
if err = rows.Err(); err != nil {
1328+
t.Fatalf("Reading rows for %q: %s", op.query, err)
1329+
}
1330+
1331+
if !reflect.DeepEqual(got, op.want) {
1332+
t.Fatalf("Unexpected output from %q\ngot:\n%s\n\nwant:\n%s", op.query, strings.Join(got, "\n"), strings.Join(op.want, "\n"))
1333+
}
1334+
}
1335+
}
1336+
12161337
func TestDeclTypes(t *testing.T) {
12171338

12181339
d := SQLiteDriver{}

0 commit comments

Comments
 (0)