Skip to content

Commit ed69081

Browse files
authored
Merge pull request #479 from kenshaw/move-registeraggregator
Move RegisterAggregator implementation
2 parents 5349436 + 7174000 commit ed69081

File tree

4 files changed

+187
-201
lines changed

4 files changed

+187
-201
lines changed

sqlite3.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ int _sqlite3_create_function(
100100
}
101101
102102
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
103+
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
104+
void doneTrampoline(sqlite3_context*);
103105
104106
int compareTrampoline(void*, int, char*, int, char*);
105107
int commitHookTrampoline(void*);
@@ -503,6 +505,131 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe
503505
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
504506
}
505507

508+
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
509+
//
510+
// Because aggregation is incremental, it's implemented in Go with a
511+
// type that has 2 methods: func Step(values) accumulates one row of
512+
// data into the accumulator, and func Done() ret finalizes and
513+
// returns the aggregate value. "values" and "ret" may be any type
514+
// supported by RegisterFunc.
515+
//
516+
// RegisterAggregator takes as implementation a constructor function
517+
// that constructs an instance of the aggregator type each time an
518+
// aggregation begins. The constructor must return a pointer to a
519+
// type, or an interface that implements Step() and Done().
520+
//
521+
// The constructor function and the Step/Done methods may optionally
522+
// return an error in addition to their other return values.
523+
//
524+
// See _example/go_custom_funcs for a detailed example.
525+
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
526+
var ai aggInfo
527+
ai.constructor = reflect.ValueOf(impl)
528+
t := ai.constructor.Type()
529+
if t.Kind() != reflect.Func {
530+
return errors.New("non-function passed to RegisterAggregator")
531+
}
532+
if t.NumOut() != 1 && t.NumOut() != 2 {
533+
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
534+
}
535+
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
536+
return errors.New("Second return value of SQLite function must be error")
537+
}
538+
if t.NumIn() != 0 {
539+
return errors.New("SQLite aggregator constructors must not have arguments")
540+
}
541+
542+
agg := t.Out(0)
543+
switch agg.Kind() {
544+
case reflect.Ptr, reflect.Interface:
545+
default:
546+
return errors.New("SQlite aggregator constructor must return a pointer object")
547+
}
548+
stepFn, found := agg.MethodByName("Step")
549+
if !found {
550+
return errors.New("SQlite aggregator doesn't have a Step() function")
551+
}
552+
step := stepFn.Type
553+
if step.NumOut() != 0 && step.NumOut() != 1 {
554+
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
555+
}
556+
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
557+
return errors.New("type of SQlite aggregator Step() return value must be error")
558+
}
559+
560+
stepNArgs := step.NumIn()
561+
start := 0
562+
if agg.Kind() == reflect.Ptr {
563+
// Skip over the method receiver
564+
stepNArgs--
565+
start++
566+
}
567+
if step.IsVariadic() {
568+
stepNArgs--
569+
}
570+
for i := start; i < start+stepNArgs; i++ {
571+
conv, err := callbackArg(step.In(i))
572+
if err != nil {
573+
return err
574+
}
575+
ai.stepArgConverters = append(ai.stepArgConverters, conv)
576+
}
577+
if step.IsVariadic() {
578+
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
579+
if err != nil {
580+
return err
581+
}
582+
ai.stepVariadicConverter = conv
583+
// Pass -1 to sqlite so that it allows any number of
584+
// arguments. The call helper verifies that the minimum number
585+
// of arguments is present for variadic functions.
586+
stepNArgs = -1
587+
}
588+
589+
doneFn, found := agg.MethodByName("Done")
590+
if !found {
591+
return errors.New("SQlite aggregator doesn't have a Done() function")
592+
}
593+
done := doneFn.Type
594+
doneNArgs := done.NumIn()
595+
if agg.Kind() == reflect.Ptr {
596+
// Skip over the method receiver
597+
doneNArgs--
598+
}
599+
if doneNArgs != 0 {
600+
return errors.New("SQlite aggregator Done() function must have no arguments")
601+
}
602+
if done.NumOut() != 1 && done.NumOut() != 2 {
603+
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
604+
}
605+
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
606+
return errors.New("second return value of SQLite aggregator Done() function must be error")
607+
}
608+
609+
conv, err := callbackRet(done.Out(0))
610+
if err != nil {
611+
return err
612+
}
613+
ai.doneRetConverter = conv
614+
ai.active = make(map[int64]reflect.Value)
615+
ai.next = 1
616+
617+
// ai must outlast the database connection, or we'll have dangling pointers.
618+
c.aggregators = append(c.aggregators, &ai)
619+
620+
cname := C.CString(name)
621+
defer C.free(unsafe.Pointer(cname))
622+
opts := C.SQLITE_UTF8
623+
if pure {
624+
opts |= C.SQLITE_DETERMINISTIC
625+
}
626+
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
627+
if rv != C.SQLITE_OK {
628+
return c.lastError()
629+
}
630+
return nil
631+
}
632+
506633
// AutoCommit return which currently auto commit or not.
507634
func (c *SQLiteConn) AutoCommit() bool {
508635
return int(C.sqlite3_get_autocommit(c.db)) != 0

sqlite3_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,66 @@ func TestFunctionRegistration(t *testing.T) {
12321232
}
12331233
}
12341234

1235+
type sumAggregator int64
1236+
1237+
func (s *sumAggregator) Step(x int64) {
1238+
*s += sumAggregator(x)
1239+
}
1240+
1241+
func (s *sumAggregator) Done() int64 {
1242+
return int64(*s)
1243+
}
1244+
1245+
func TestAggregatorRegistration(t *testing.T) {
1246+
customSum := func() *sumAggregator {
1247+
var ret sumAggregator
1248+
return &ret
1249+
}
1250+
1251+
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
1252+
ConnectHook: func(conn *SQLiteConn) error {
1253+
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
1254+
return err
1255+
}
1256+
return nil
1257+
},
1258+
})
1259+
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
1260+
if err != nil {
1261+
t.Fatal("Failed to open database:", err)
1262+
}
1263+
defer db.Close()
1264+
1265+
_, err = db.Exec("create table foo (department integer, profits integer)")
1266+
if err != nil {
1267+
// trace feature is not implemented
1268+
t.Skip("Failed to create table:", err)
1269+
}
1270+
1271+
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
1272+
if err != nil {
1273+
t.Fatal("Failed to insert records:", err)
1274+
}
1275+
1276+
tests := []struct {
1277+
dept, sum int64
1278+
}{
1279+
{1, 30},
1280+
{2, 42},
1281+
}
1282+
1283+
for _, test := range tests {
1284+
var ret int64
1285+
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
1286+
if err != nil {
1287+
t.Fatal("Query failed:", err)
1288+
}
1289+
if ret != test.sum {
1290+
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
1291+
}
1292+
}
1293+
}
1294+
12351295
func rot13(r rune) rune {
12361296
switch {
12371297
case r >= 'A' && r <= 'Z':

sqlite3_trace.go

Lines changed: 0 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,12 @@ package sqlite3
1414
#endif
1515
#include <stdlib.h>
1616
17-
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
18-
void doneTrampoline(sqlite3_context*);
1917
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
2018
*/
2119
import "C"
2220

2321
import (
24-
"errors"
2522
"fmt"
26-
"reflect"
2723
"strings"
2824
"sync"
2925
"unsafe"
@@ -239,131 +235,6 @@ func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
239235
return entryCopy.config, found
240236
}
241237

242-
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
243-
//
244-
// Because aggregation is incremental, it's implemented in Go with a
245-
// type that has 2 methods: func Step(values) accumulates one row of
246-
// data into the accumulator, and func Done() ret finalizes and
247-
// returns the aggregate value. "values" and "ret" may be any type
248-
// supported by RegisterFunc.
249-
//
250-
// RegisterAggregator takes as implementation a constructor function
251-
// that constructs an instance of the aggregator type each time an
252-
// aggregation begins. The constructor must return a pointer to a
253-
// type, or an interface that implements Step() and Done().
254-
//
255-
// The constructor function and the Step/Done methods may optionally
256-
// return an error in addition to their other return values.
257-
//
258-
// See _example/go_custom_funcs for a detailed example.
259-
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
260-
var ai aggInfo
261-
ai.constructor = reflect.ValueOf(impl)
262-
t := ai.constructor.Type()
263-
if t.Kind() != reflect.Func {
264-
return errors.New("non-function passed to RegisterAggregator")
265-
}
266-
if t.NumOut() != 1 && t.NumOut() != 2 {
267-
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
268-
}
269-
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
270-
return errors.New("Second return value of SQLite function must be error")
271-
}
272-
if t.NumIn() != 0 {
273-
return errors.New("SQLite aggregator constructors must not have arguments")
274-
}
275-
276-
agg := t.Out(0)
277-
switch agg.Kind() {
278-
case reflect.Ptr, reflect.Interface:
279-
default:
280-
return errors.New("SQlite aggregator constructor must return a pointer object")
281-
}
282-
stepFn, found := agg.MethodByName("Step")
283-
if !found {
284-
return errors.New("SQlite aggregator doesn't have a Step() function")
285-
}
286-
step := stepFn.Type
287-
if step.NumOut() != 0 && step.NumOut() != 1 {
288-
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
289-
}
290-
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
291-
return errors.New("type of SQlite aggregator Step() return value must be error")
292-
}
293-
294-
stepNArgs := step.NumIn()
295-
start := 0
296-
if agg.Kind() == reflect.Ptr {
297-
// Skip over the method receiver
298-
stepNArgs--
299-
start++
300-
}
301-
if step.IsVariadic() {
302-
stepNArgs--
303-
}
304-
for i := start; i < start+stepNArgs; i++ {
305-
conv, err := callbackArg(step.In(i))
306-
if err != nil {
307-
return err
308-
}
309-
ai.stepArgConverters = append(ai.stepArgConverters, conv)
310-
}
311-
if step.IsVariadic() {
312-
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
313-
if err != nil {
314-
return err
315-
}
316-
ai.stepVariadicConverter = conv
317-
// Pass -1 to sqlite so that it allows any number of
318-
// arguments. The call helper verifies that the minimum number
319-
// of arguments is present for variadic functions.
320-
stepNArgs = -1
321-
}
322-
323-
doneFn, found := agg.MethodByName("Done")
324-
if !found {
325-
return errors.New("SQlite aggregator doesn't have a Done() function")
326-
}
327-
done := doneFn.Type
328-
doneNArgs := done.NumIn()
329-
if agg.Kind() == reflect.Ptr {
330-
// Skip over the method receiver
331-
doneNArgs--
332-
}
333-
if doneNArgs != 0 {
334-
return errors.New("SQlite aggregator Done() function must have no arguments")
335-
}
336-
if done.NumOut() != 1 && done.NumOut() != 2 {
337-
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
338-
}
339-
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
340-
return errors.New("second return value of SQLite aggregator Done() function must be error")
341-
}
342-
343-
conv, err := callbackRet(done.Out(0))
344-
if err != nil {
345-
return err
346-
}
347-
ai.doneRetConverter = conv
348-
ai.active = make(map[int64]reflect.Value)
349-
ai.next = 1
350-
351-
// ai must outlast the database connection, or we'll have dangling pointers.
352-
c.aggregators = append(c.aggregators, &ai)
353-
354-
cname := C.CString(name)
355-
defer C.free(unsafe.Pointer(cname))
356-
opts := C.SQLITE_UTF8
357-
if pure {
358-
opts |= C.SQLITE_DETERMINISTIC
359-
}
360-
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
361-
if rv != C.SQLITE_OK {
362-
return c.lastError()
363-
}
364-
return nil
365-
}
366-
367238
// SetTrace installs or removes the trace callback for the given database connection.
368239
// It's not named 'RegisterTrace' because only one callback can be kept and called.
369240
// Calling SetTrace a second time on same database connection

0 commit comments

Comments
 (0)