Skip to content

Commit 7174000

Browse files
committed
Move RegisterAggregator implementation
The SQLiteConn.RegisterAggregator implementation was defined in sqlite3_trace.go file, which is guarded with a build constraint. This change simply moves RegisterAggregator to the main sqlite3.go file, and moves accompanying unit tests. The rationale for this move is that it was not possible for downstream using packages to use RegisterAggregator without also specifying (and notifying the user) the 'trace' build tag.
1 parent 615c193 commit 7174000

File tree

4 files changed

+187
-201
lines changed

4 files changed

+187
-201
lines changed

sqlite3.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ int _sqlite3_create_function(
100100
}
101101
102102
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
103+
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
104+
void doneTrampoline(sqlite3_context*);
103105
104106
int compareTrampoline(void*, int, char*, int, char*);
105107
int commitHookTrampoline(void*);
@@ -477,6 +479,131 @@ func sqlite3CreateFunction(db *C.sqlite3, zFunctionName *C.char, nArg C.int, eTe
477479
return C._sqlite3_create_function(db, zFunctionName, nArg, eTextRep, C.uintptr_t(pApp), (*[0]byte)(unsafe.Pointer(xFunc)), (*[0]byte)(unsafe.Pointer(xStep)), (*[0]byte)(unsafe.Pointer(xFinal)))
478480
}
479481

482+
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
483+
//
484+
// Because aggregation is incremental, it's implemented in Go with a
485+
// type that has 2 methods: func Step(values) accumulates one row of
486+
// data into the accumulator, and func Done() ret finalizes and
487+
// returns the aggregate value. "values" and "ret" may be any type
488+
// supported by RegisterFunc.
489+
//
490+
// RegisterAggregator takes as implementation a constructor function
491+
// that constructs an instance of the aggregator type each time an
492+
// aggregation begins. The constructor must return a pointer to a
493+
// type, or an interface that implements Step() and Done().
494+
//
495+
// The constructor function and the Step/Done methods may optionally
496+
// return an error in addition to their other return values.
497+
//
498+
// See _example/go_custom_funcs for a detailed example.
499+
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
500+
var ai aggInfo
501+
ai.constructor = reflect.ValueOf(impl)
502+
t := ai.constructor.Type()
503+
if t.Kind() != reflect.Func {
504+
return errors.New("non-function passed to RegisterAggregator")
505+
}
506+
if t.NumOut() != 1 && t.NumOut() != 2 {
507+
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
508+
}
509+
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
510+
return errors.New("Second return value of SQLite function must be error")
511+
}
512+
if t.NumIn() != 0 {
513+
return errors.New("SQLite aggregator constructors must not have arguments")
514+
}
515+
516+
agg := t.Out(0)
517+
switch agg.Kind() {
518+
case reflect.Ptr, reflect.Interface:
519+
default:
520+
return errors.New("SQlite aggregator constructor must return a pointer object")
521+
}
522+
stepFn, found := agg.MethodByName("Step")
523+
if !found {
524+
return errors.New("SQlite aggregator doesn't have a Step() function")
525+
}
526+
step := stepFn.Type
527+
if step.NumOut() != 0 && step.NumOut() != 1 {
528+
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
529+
}
530+
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
531+
return errors.New("type of SQlite aggregator Step() return value must be error")
532+
}
533+
534+
stepNArgs := step.NumIn()
535+
start := 0
536+
if agg.Kind() == reflect.Ptr {
537+
// Skip over the method receiver
538+
stepNArgs--
539+
start++
540+
}
541+
if step.IsVariadic() {
542+
stepNArgs--
543+
}
544+
for i := start; i < start+stepNArgs; i++ {
545+
conv, err := callbackArg(step.In(i))
546+
if err != nil {
547+
return err
548+
}
549+
ai.stepArgConverters = append(ai.stepArgConverters, conv)
550+
}
551+
if step.IsVariadic() {
552+
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
553+
if err != nil {
554+
return err
555+
}
556+
ai.stepVariadicConverter = conv
557+
// Pass -1 to sqlite so that it allows any number of
558+
// arguments. The call helper verifies that the minimum number
559+
// of arguments is present for variadic functions.
560+
stepNArgs = -1
561+
}
562+
563+
doneFn, found := agg.MethodByName("Done")
564+
if !found {
565+
return errors.New("SQlite aggregator doesn't have a Done() function")
566+
}
567+
done := doneFn.Type
568+
doneNArgs := done.NumIn()
569+
if agg.Kind() == reflect.Ptr {
570+
// Skip over the method receiver
571+
doneNArgs--
572+
}
573+
if doneNArgs != 0 {
574+
return errors.New("SQlite aggregator Done() function must have no arguments")
575+
}
576+
if done.NumOut() != 1 && done.NumOut() != 2 {
577+
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
578+
}
579+
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
580+
return errors.New("second return value of SQLite aggregator Done() function must be error")
581+
}
582+
583+
conv, err := callbackRet(done.Out(0))
584+
if err != nil {
585+
return err
586+
}
587+
ai.doneRetConverter = conv
588+
ai.active = make(map[int64]reflect.Value)
589+
ai.next = 1
590+
591+
// ai must outlast the database connection, or we'll have dangling pointers.
592+
c.aggregators = append(c.aggregators, &ai)
593+
594+
cname := C.CString(name)
595+
defer C.free(unsafe.Pointer(cname))
596+
opts := C.SQLITE_UTF8
597+
if pure {
598+
opts |= C.SQLITE_DETERMINISTIC
599+
}
600+
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
601+
if rv != C.SQLITE_OK {
602+
return c.lastError()
603+
}
604+
return nil
605+
}
606+
480607
// AutoCommit return which currently auto commit or not.
481608
func (c *SQLiteConn) AutoCommit() bool {
482609
return int(C.sqlite3_get_autocommit(c.db)) != 0

sqlite3_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,66 @@ func TestFunctionRegistration(t *testing.T) {
12321232
}
12331233
}
12341234

1235+
type sumAggregator int64
1236+
1237+
func (s *sumAggregator) Step(x int64) {
1238+
*s += sumAggregator(x)
1239+
}
1240+
1241+
func (s *sumAggregator) Done() int64 {
1242+
return int64(*s)
1243+
}
1244+
1245+
func TestAggregatorRegistration(t *testing.T) {
1246+
customSum := func() *sumAggregator {
1247+
var ret sumAggregator
1248+
return &ret
1249+
}
1250+
1251+
sql.Register("sqlite3_AggregatorRegistration", &SQLiteDriver{
1252+
ConnectHook: func(conn *SQLiteConn) error {
1253+
if err := conn.RegisterAggregator("customSum", customSum, true); err != nil {
1254+
return err
1255+
}
1256+
return nil
1257+
},
1258+
})
1259+
db, err := sql.Open("sqlite3_AggregatorRegistration", ":memory:")
1260+
if err != nil {
1261+
t.Fatal("Failed to open database:", err)
1262+
}
1263+
defer db.Close()
1264+
1265+
_, err = db.Exec("create table foo (department integer, profits integer)")
1266+
if err != nil {
1267+
// trace feature is not implemented
1268+
t.Skip("Failed to create table:", err)
1269+
}
1270+
1271+
_, err = db.Exec("insert into foo values (1, 10), (1, 20), (2, 42)")
1272+
if err != nil {
1273+
t.Fatal("Failed to insert records:", err)
1274+
}
1275+
1276+
tests := []struct {
1277+
dept, sum int64
1278+
}{
1279+
{1, 30},
1280+
{2, 42},
1281+
}
1282+
1283+
for _, test := range tests {
1284+
var ret int64
1285+
err = db.QueryRow("select customSum(profits) from foo where department = $1 group by department", test.dept).Scan(&ret)
1286+
if err != nil {
1287+
t.Fatal("Query failed:", err)
1288+
}
1289+
if ret != test.sum {
1290+
t.Fatalf("Custom sum returned wrong value, got %d, want %d", ret, test.sum)
1291+
}
1292+
}
1293+
}
1294+
12351295
func rot13(r rune) rune {
12361296
switch {
12371297
case r >= 'A' && r <= 'Z':

sqlite3_trace.go

Lines changed: 0 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,12 @@ package sqlite3
1414
#endif
1515
#include <stdlib.h>
1616
17-
void stepTrampoline(sqlite3_context*, int, sqlite3_value**);
18-
void doneTrampoline(sqlite3_context*);
1917
int traceCallbackTrampoline(unsigned int traceEventCode, void *ctx, void *p, void *x);
2018
*/
2119
import "C"
2220

2321
import (
24-
"errors"
2522
"fmt"
26-
"reflect"
2723
"strings"
2824
"sync"
2925
"unsafe"
@@ -239,131 +235,6 @@ func popTraceMapping(connHandle uintptr) (TraceConfig, bool) {
239235
return entryCopy.config, found
240236
}
241237

242-
// RegisterAggregator makes a Go type available as a SQLite aggregation function.
243-
//
244-
// Because aggregation is incremental, it's implemented in Go with a
245-
// type that has 2 methods: func Step(values) accumulates one row of
246-
// data into the accumulator, and func Done() ret finalizes and
247-
// returns the aggregate value. "values" and "ret" may be any type
248-
// supported by RegisterFunc.
249-
//
250-
// RegisterAggregator takes as implementation a constructor function
251-
// that constructs an instance of the aggregator type each time an
252-
// aggregation begins. The constructor must return a pointer to a
253-
// type, or an interface that implements Step() and Done().
254-
//
255-
// The constructor function and the Step/Done methods may optionally
256-
// return an error in addition to their other return values.
257-
//
258-
// See _example/go_custom_funcs for a detailed example.
259-
func (c *SQLiteConn) RegisterAggregator(name string, impl interface{}, pure bool) error {
260-
var ai aggInfo
261-
ai.constructor = reflect.ValueOf(impl)
262-
t := ai.constructor.Type()
263-
if t.Kind() != reflect.Func {
264-
return errors.New("non-function passed to RegisterAggregator")
265-
}
266-
if t.NumOut() != 1 && t.NumOut() != 2 {
267-
return errors.New("SQLite aggregator constructors must return 1 or 2 values")
268-
}
269-
if t.NumOut() == 2 && !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
270-
return errors.New("Second return value of SQLite function must be error")
271-
}
272-
if t.NumIn() != 0 {
273-
return errors.New("SQLite aggregator constructors must not have arguments")
274-
}
275-
276-
agg := t.Out(0)
277-
switch agg.Kind() {
278-
case reflect.Ptr, reflect.Interface:
279-
default:
280-
return errors.New("SQlite aggregator constructor must return a pointer object")
281-
}
282-
stepFn, found := agg.MethodByName("Step")
283-
if !found {
284-
return errors.New("SQlite aggregator doesn't have a Step() function")
285-
}
286-
step := stepFn.Type
287-
if step.NumOut() != 0 && step.NumOut() != 1 {
288-
return errors.New("SQlite aggregator Step() function must return 0 or 1 values")
289-
}
290-
if step.NumOut() == 1 && !step.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
291-
return errors.New("type of SQlite aggregator Step() return value must be error")
292-
}
293-
294-
stepNArgs := step.NumIn()
295-
start := 0
296-
if agg.Kind() == reflect.Ptr {
297-
// Skip over the method receiver
298-
stepNArgs--
299-
start++
300-
}
301-
if step.IsVariadic() {
302-
stepNArgs--
303-
}
304-
for i := start; i < start+stepNArgs; i++ {
305-
conv, err := callbackArg(step.In(i))
306-
if err != nil {
307-
return err
308-
}
309-
ai.stepArgConverters = append(ai.stepArgConverters, conv)
310-
}
311-
if step.IsVariadic() {
312-
conv, err := callbackArg(t.In(start + stepNArgs).Elem())
313-
if err != nil {
314-
return err
315-
}
316-
ai.stepVariadicConverter = conv
317-
// Pass -1 to sqlite so that it allows any number of
318-
// arguments. The call helper verifies that the minimum number
319-
// of arguments is present for variadic functions.
320-
stepNArgs = -1
321-
}
322-
323-
doneFn, found := agg.MethodByName("Done")
324-
if !found {
325-
return errors.New("SQlite aggregator doesn't have a Done() function")
326-
}
327-
done := doneFn.Type
328-
doneNArgs := done.NumIn()
329-
if agg.Kind() == reflect.Ptr {
330-
// Skip over the method receiver
331-
doneNArgs--
332-
}
333-
if doneNArgs != 0 {
334-
return errors.New("SQlite aggregator Done() function must have no arguments")
335-
}
336-
if done.NumOut() != 1 && done.NumOut() != 2 {
337-
return errors.New("SQLite aggregator Done() function must return 1 or 2 values")
338-
}
339-
if done.NumOut() == 2 && !done.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
340-
return errors.New("second return value of SQLite aggregator Done() function must be error")
341-
}
342-
343-
conv, err := callbackRet(done.Out(0))
344-
if err != nil {
345-
return err
346-
}
347-
ai.doneRetConverter = conv
348-
ai.active = make(map[int64]reflect.Value)
349-
ai.next = 1
350-
351-
// ai must outlast the database connection, or we'll have dangling pointers.
352-
c.aggregators = append(c.aggregators, &ai)
353-
354-
cname := C.CString(name)
355-
defer C.free(unsafe.Pointer(cname))
356-
opts := C.SQLITE_UTF8
357-
if pure {
358-
opts |= C.SQLITE_DETERMINISTIC
359-
}
360-
rv := sqlite3CreateFunction(c.db, cname, C.int(stepNArgs), C.int(opts), newHandle(c, &ai), nil, C.stepTrampoline, C.doneTrampoline)
361-
if rv != C.SQLITE_OK {
362-
return c.lastError()
363-
}
364-
return nil
365-
}
366-
367238
// SetTrace installs or removes the trace callback for the given database connection.
368239
// It's not named 'RegisterTrace' because only one callback can be kept and called.
369240
// Calling SetTrace a second time on same database connection

0 commit comments

Comments
 (0)