Skip to content

Commit 41c52a3

Browse files
committed
Add support for pre-update hooks
1 parent 696e2e4 commit 41c52a3

File tree

4 files changed

+532
-2
lines changed

4 files changed

+532
-2
lines changed

callback.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,21 @@ func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, row
7777
callback(op, C.GoString(db), C.GoString(table), rowid)
7878
}
7979

80+
//export preUpdateHookTrampoline
81+
func preUpdateHookTrampoline(handle uintptr, dbHandle uintptr, op int, db *C.char, table *C.char, oldrowid int64, newrowid int64) {
82+
hval := lookupHandleVal(handle)
83+
data := SQLitePreUpdateData{
84+
Conn: hval.db,
85+
Op: op,
86+
DatabaseName: C.GoString(db),
87+
TableName: C.GoString(table),
88+
OldRowID: oldrowid,
89+
NewRowID: newrowid,
90+
}
91+
callback := hval.val.(func(SQLitePreUpdateData))
92+
callback(data)
93+
}
94+
8095
// Use handles to avoid passing Go pointers to C.
8196

8297
type handleVal struct {
@@ -97,7 +112,7 @@ func newHandle(db *SQLiteConn, v interface{}) uintptr {
97112
return i
98113
}
99114

100-
func lookupHandle(handle uintptr) interface{} {
115+
func lookupHandleVal(handle uintptr) handleVal {
101116
handleLock.Lock()
102117
defer handleLock.Unlock()
103118
r, ok := handleVals[handle]
@@ -108,7 +123,11 @@ func lookupHandle(handle uintptr) interface{} {
108123
panic("invalid handle")
109124
}
110125
}
111-
return r.val
126+
return r
127+
}
128+
129+
func lookupHandle(handle uintptr) interface{} {
130+
return lookupHandleVal(handle).val
112131
}
113132

114133
func deleteHandles(db *SQLiteConn) {

convert.go

Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
// Extracted from Go database/sql source code
2+
3+
// Copyright 2011 The Go Authors. All rights reserved.
4+
// Use of this source code is governed by a BSD-style
5+
// license that can be found in the LICENSE file.
6+
7+
// Type conversions for Scan.
8+
9+
package sqlite3
10+
11+
import (
12+
"database/sql"
13+
"database/sql/driver"
14+
"errors"
15+
"fmt"
16+
"reflect"
17+
"strconv"
18+
"time"
19+
)
20+
21+
var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
22+
23+
// convertAssign copies to dest the value in src, converting it if possible.
24+
// An error is returned if the copy would result in loss of information.
25+
// dest should be a pointer type.
26+
func convertAssign(dest, src interface{}) error {
27+
// Common cases, without reflect.
28+
switch s := src.(type) {
29+
case string:
30+
switch d := dest.(type) {
31+
case *string:
32+
if d == nil {
33+
return errNilPtr
34+
}
35+
*d = s
36+
return nil
37+
case *[]byte:
38+
if d == nil {
39+
return errNilPtr
40+
}
41+
*d = []byte(s)
42+
return nil
43+
case *sql.RawBytes:
44+
if d == nil {
45+
return errNilPtr
46+
}
47+
*d = append((*d)[:0], s...)
48+
return nil
49+
}
50+
case []byte:
51+
switch d := dest.(type) {
52+
case *string:
53+
if d == nil {
54+
return errNilPtr
55+
}
56+
*d = string(s)
57+
return nil
58+
case *interface{}:
59+
if d == nil {
60+
return errNilPtr
61+
}
62+
*d = cloneBytes(s)
63+
return nil
64+
case *[]byte:
65+
if d == nil {
66+
return errNilPtr
67+
}
68+
*d = cloneBytes(s)
69+
return nil
70+
case *sql.RawBytes:
71+
if d == nil {
72+
return errNilPtr
73+
}
74+
*d = s
75+
return nil
76+
}
77+
case time.Time:
78+
switch d := dest.(type) {
79+
case *time.Time:
80+
*d = s
81+
return nil
82+
case *string:
83+
*d = s.Format(time.RFC3339Nano)
84+
return nil
85+
case *[]byte:
86+
if d == nil {
87+
return errNilPtr
88+
}
89+
*d = []byte(s.Format(time.RFC3339Nano))
90+
return nil
91+
case *sql.RawBytes:
92+
if d == nil {
93+
return errNilPtr
94+
}
95+
*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
96+
return nil
97+
}
98+
case nil:
99+
switch d := dest.(type) {
100+
case *interface{}:
101+
if d == nil {
102+
return errNilPtr
103+
}
104+
*d = nil
105+
return nil
106+
case *[]byte:
107+
if d == nil {
108+
return errNilPtr
109+
}
110+
*d = nil
111+
return nil
112+
case *sql.RawBytes:
113+
if d == nil {
114+
return errNilPtr
115+
}
116+
*d = nil
117+
return nil
118+
}
119+
}
120+
121+
var sv reflect.Value
122+
123+
switch d := dest.(type) {
124+
case *string:
125+
sv = reflect.ValueOf(src)
126+
switch sv.Kind() {
127+
case reflect.Bool,
128+
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
129+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
130+
reflect.Float32, reflect.Float64:
131+
*d = asString(src)
132+
return nil
133+
}
134+
case *[]byte:
135+
sv = reflect.ValueOf(src)
136+
if b, ok := asBytes(nil, sv); ok {
137+
*d = b
138+
return nil
139+
}
140+
case *sql.RawBytes:
141+
sv = reflect.ValueOf(src)
142+
if b, ok := asBytes([]byte(*d)[:0], sv); ok {
143+
*d = sql.RawBytes(b)
144+
return nil
145+
}
146+
case *bool:
147+
bv, err := driver.Bool.ConvertValue(src)
148+
if err == nil {
149+
*d = bv.(bool)
150+
}
151+
return err
152+
case *interface{}:
153+
*d = src
154+
return nil
155+
}
156+
157+
if scanner, ok := dest.(sql.Scanner); ok {
158+
return scanner.Scan(src)
159+
}
160+
161+
dpv := reflect.ValueOf(dest)
162+
if dpv.Kind() != reflect.Ptr {
163+
return errors.New("destination not a pointer")
164+
}
165+
if dpv.IsNil() {
166+
return errNilPtr
167+
}
168+
169+
if !sv.IsValid() {
170+
sv = reflect.ValueOf(src)
171+
}
172+
173+
dv := reflect.Indirect(dpv)
174+
if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
175+
switch b := src.(type) {
176+
case []byte:
177+
dv.Set(reflect.ValueOf(cloneBytes(b)))
178+
default:
179+
dv.Set(sv)
180+
}
181+
return nil
182+
}
183+
184+
if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
185+
dv.Set(sv.Convert(dv.Type()))
186+
return nil
187+
}
188+
189+
// The following conversions use a string value as an intermediate representation
190+
// to convert between various numeric types.
191+
//
192+
// This also allows scanning into user defined types such as "type Int int64".
193+
// For symmetry, also check for string destination types.
194+
switch dv.Kind() {
195+
case reflect.Ptr:
196+
if src == nil {
197+
dv.Set(reflect.Zero(dv.Type()))
198+
return nil
199+
} else {
200+
dv.Set(reflect.New(dv.Type().Elem()))
201+
return convertAssign(dv.Interface(), src)
202+
}
203+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
204+
s := asString(src)
205+
i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
206+
if err != nil {
207+
err = strconvErr(err)
208+
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
209+
}
210+
dv.SetInt(i64)
211+
return nil
212+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
213+
s := asString(src)
214+
u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
215+
if err != nil {
216+
err = strconvErr(err)
217+
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
218+
}
219+
dv.SetUint(u64)
220+
return nil
221+
case reflect.Float32, reflect.Float64:
222+
s := asString(src)
223+
f64, err := strconv.ParseFloat(s, dv.Type().Bits())
224+
if err != nil {
225+
err = strconvErr(err)
226+
return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
227+
}
228+
dv.SetFloat(f64)
229+
return nil
230+
case reflect.String:
231+
switch v := src.(type) {
232+
case string:
233+
dv.SetString(v)
234+
return nil
235+
case []byte:
236+
dv.SetString(string(v))
237+
return nil
238+
}
239+
}
240+
241+
return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
242+
}
243+
244+
func strconvErr(err error) error {
245+
if ne, ok := err.(*strconv.NumError); ok {
246+
return ne.Err
247+
}
248+
return err
249+
}
250+
251+
func cloneBytes(b []byte) []byte {
252+
if b == nil {
253+
return nil
254+
} else {
255+
c := make([]byte, len(b))
256+
copy(c, b)
257+
return c
258+
}
259+
}
260+
261+
func asString(src interface{}) string {
262+
switch v := src.(type) {
263+
case string:
264+
return v
265+
case []byte:
266+
return string(v)
267+
}
268+
rv := reflect.ValueOf(src)
269+
switch rv.Kind() {
270+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
271+
return strconv.FormatInt(rv.Int(), 10)
272+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
273+
return strconv.FormatUint(rv.Uint(), 10)
274+
case reflect.Float64:
275+
return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
276+
case reflect.Float32:
277+
return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
278+
case reflect.Bool:
279+
return strconv.FormatBool(rv.Bool())
280+
}
281+
return fmt.Sprintf("%v", src)
282+
}
283+
284+
func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
285+
switch rv.Kind() {
286+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
287+
return strconv.AppendInt(buf, rv.Int(), 10), true
288+
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
289+
return strconv.AppendUint(buf, rv.Uint(), 10), true
290+
case reflect.Float32:
291+
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
292+
case reflect.Float64:
293+
return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
294+
case reflect.Bool:
295+
return strconv.AppendBool(buf, rv.Bool()), true
296+
case reflect.String:
297+
s := rv.String()
298+
return append(buf, s...), true
299+
}
300+
return
301+
}

0 commit comments

Comments
 (0)