Skip to content

Commit 132eeed

Browse files
authored
Merge branch 'master' into master
2 parents 0430b37 + b8d537f commit 132eeed

File tree

9 files changed

+823
-454
lines changed

9 files changed

+823
-454
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ FAQ
6565
6666
* Want to get time.Time with current locale
6767

68-
Use `loc=auto` in SQLite3 filename schema like `file:foo.db?loc=auto`.
68+
Use `_loc=auto` in SQLite3 filename schema like `file:foo.db?_loc=auto`.
6969

7070
* Can I use this in multiple routines concurrently?
7171

_example/hook/hook.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ func main() {
1414
&sqlite3.SQLiteDriver{
1515
ConnectHook: func(conn *sqlite3.SQLiteConn) error {
1616
sqlite3conn = append(sqlite3conn, conn)
17+
conn.RegisterUpdateHook(func(op int, db string, table string, rowid int64) {
18+
switch op {
19+
case sqlite3.SQLITE_INSERT:
20+
log.Println("Notified of insert on db", db, "table", table, "rowid", rowid)
21+
}
22+
})
1723
return nil
1824
},
1925
})

callback.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,24 @@ func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.ch
5959
return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
6060
}
6161

62+
//export commitHookTrampoline
63+
func commitHookTrampoline(handle uintptr) int {
64+
callback := lookupHandle(handle).(func() int)
65+
return callback()
66+
}
67+
68+
//export rollbackHookTrampoline
69+
func rollbackHookTrampoline(handle uintptr) {
70+
callback := lookupHandle(handle).(func())
71+
callback()
72+
}
73+
74+
//export updateHookTrampoline
75+
func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
76+
callback := lookupHandle(handle).(func(int, string, string, int64))
77+
callback(op, C.GoString(db), C.GoString(table), rowid)
78+
}
79+
6280
// Use handles to avoid passing Go pointers to C.
6381

6482
type handleVal struct {

sqlite3.go

Lines changed: 136 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ package sqlite3
77

88
/*
99
#cgo CFLAGS: -std=gnu99
10-
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
10+
#cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1
1111
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
1212
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
1313
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
@@ -102,6 +102,9 @@ int _sqlite3_create_function(
102102
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
103103
104104
int compareTrampoline(void*, int, char*, int, char*);
105+
int commitHookTrampoline(void*);
106+
void rollbackHookTrampoline(void*);
107+
void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
105108
*/
106109
import "C"
107110
import (
@@ -115,6 +118,7 @@ import (
115118
"runtime"
116119
"strconv"
117120
"strings"
121+
"sync"
118122
"time"
119123
"unsafe"
120124

@@ -151,6 +155,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
151155
return libVersion, libVersionNumber, sourceID
152156
}
153157

158+
const (
159+
SQLITE_DELETE = C.SQLITE_DELETE
160+
SQLITE_INSERT = C.SQLITE_INSERT
161+
SQLITE_UPDATE = C.SQLITE_UPDATE
162+
)
163+
154164
// SQLiteDriver implement sql.Driver.
155165
type SQLiteDriver struct {
156166
Extensions []string
@@ -159,6 +169,7 @@ type SQLiteDriver struct {
159169

160170
// SQLiteConn implement sql.Conn.
161171
type SQLiteConn struct {
172+
mu sync.Mutex
162173
db *C.sqlite3
163174
loc *time.Location
164175
txlock string
@@ -173,6 +184,7 @@ type SQLiteTx struct {
173184

174185
// SQLiteStmt implement sql.Stmt.
175186
type SQLiteStmt struct {
187+
mu sync.Mutex
176188
c *SQLiteConn
177189
s *C.sqlite3_stmt
178190
t string
@@ -193,6 +205,7 @@ type SQLiteRows struct {
193205
cols []string
194206
decltype []string
195207
cls bool
208+
closed bool
196209
done chan struct{}
197210
}
198211

@@ -338,6 +351,51 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
338351
return nil
339352
}
340353

354+
// RegisterCommitHook sets the commit hook for a connection.
355+
//
356+
// If the callback returns non-zero the transaction will become a rollback.
357+
//
358+
// If there is an existing commit hook for this connection, it will be
359+
// removed. If callback is nil the existing hook (if any) will be removed
360+
// without creating a new one.
361+
func (c *SQLiteConn) RegisterCommitHook(callback func() int) {
362+
if callback == nil {
363+
C.sqlite3_commit_hook(c.db, nil, nil)
364+
} else {
365+
C.sqlite3_commit_hook(c.db, (*[0]byte)(unsafe.Pointer(C.commitHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
366+
}
367+
}
368+
369+
// RegisterRollbackHook sets the rollback hook for a connection.
370+
//
371+
// If there is an existing rollback hook for this connection, it will be
372+
// removed. If callback is nil the existing hook (if any) will be removed
373+
// without creating a new one.
374+
func (c *SQLiteConn) RegisterRollbackHook(callback func()) {
375+
if callback == nil {
376+
C.sqlite3_rollback_hook(c.db, nil, nil)
377+
} else {
378+
C.sqlite3_rollback_hook(c.db, (*[0]byte)(unsafe.Pointer(C.rollbackHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
379+
}
380+
}
381+
382+
// RegisterUpdateHook sets the update hook for a connection.
383+
//
384+
// The parameters to the callback are the operation (one of the constants
385+
// SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
386+
// table name, and the rowid.
387+
//
388+
// If there is an existing update hook for this connection, it will be
389+
// removed. If callback is nil the existing hook (if any) will be removed
390+
// without creating a new one.
391+
func (c *SQLiteConn) RegisterUpdateHook(callback func(int, string, string, int64)) {
392+
if callback == nil {
393+
C.sqlite3_update_hook(c.db, nil, nil)
394+
} else {
395+
C.sqlite3_update_hook(c.db, (*[0]byte)(unsafe.Pointer(C.updateHookTrampoline)), unsafe.Pointer(newHandle(c, callback)))
396+
}
397+
}
398+
341399
// RegisterFunc makes a Go function available as a SQLite function.
342400
//
343401
// The Go function can have arguments of the following types: any
@@ -568,6 +626,8 @@ func errorString(err Error) string {
568626
// "deferred", "exclusive".
569627
// _foreign_keys=X
570628
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
629+
// _recursive_triggers=X
630+
// Enable or disable recursive triggers. X can be 1 or 0.
571631
func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
572632
if C.sqlite3_threadsafe() == 0 {
573633
return nil, errors.New("sqlite library was not compiled for thread-safe operation")
@@ -577,6 +637,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
577637
txlock := "BEGIN"
578638
busyTimeout := 5000
579639
foreignKeys := -1
640+
recursiveTriggers := -1
580641
pos := strings.IndexRune(dsn, '?')
581642
if pos >= 1 {
582643
params, err := url.ParseQuery(dsn[pos+1:])
@@ -631,6 +692,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
631692
}
632693
}
633694

695+
// _recursive_triggers
696+
if val := params.Get("_recursive_triggers"); val != "" {
697+
switch val {
698+
case "1":
699+
recursiveTriggers = 1
700+
case "0":
701+
recursiveTriggers = 0
702+
default:
703+
return nil, fmt.Errorf("Invalid _recursive_triggers: %v", val)
704+
}
705+
}
706+
634707
if !strings.HasPrefix(dsn, "file:") {
635708
dsn = dsn[:pos]
636709
}
@@ -677,6 +750,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
677750
return nil, err
678751
}
679752
}
753+
if recursiveTriggers == 0 {
754+
if err := exec("PRAGMA recursive_triggers = OFF;"); err != nil {
755+
C.sqlite3_close_v2(db)
756+
return nil, err
757+
}
758+
} else if recursiveTriggers == 1 {
759+
if err := exec("PRAGMA recursive_triggers = ON;"); err != nil {
760+
C.sqlite3_close_v2(db)
761+
return nil, err
762+
}
763+
}
680764

681765
conn := &SQLiteConn{db: db, loc: loc, txlock: txlock}
682766

@@ -704,11 +788,22 @@ func (c *SQLiteConn) Close() error {
704788
return c.lastError()
705789
}
706790
deleteHandles(c)
791+
c.mu.Lock()
707792
c.db = nil
793+
c.mu.Unlock()
708794
runtime.SetFinalizer(c, nil)
709795
return nil
710796
}
711797

798+
func (c *SQLiteConn) dbConnOpen() bool {
799+
if c == nil {
800+
return false
801+
}
802+
c.mu.Lock()
803+
defer c.mu.Unlock()
804+
return c.db != nil
805+
}
806+
712807
// Prepare the query string. Return a new statement.
713808
func (c *SQLiteConn) Prepare(query string) (driver.Stmt, error) {
714809
return c.prepare(context.Background(), query)
@@ -734,14 +829,17 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
734829

735830
// Close the statement.
736831
func (s *SQLiteStmt) Close() error {
832+
s.mu.Lock()
833+
defer s.mu.Unlock()
737834
if s.closed {
738835
return nil
739836
}
740837
s.closed = true
741-
if s.c == nil || s.c.db == nil {
838+
if !s.c.dbConnOpen() {
742839
return errors.New("sqlite statement with already closed database connection")
743840
}
744841
rv := C.sqlite3_finalize(s.s)
842+
s.s = nil
745843
if rv != C.SQLITE_OK {
746844
return s.c.lastError()
747845
}
@@ -759,6 +857,8 @@ type bindArg struct {
759857
v driver.Value
760858
}
761859

860+
var placeHolder = []byte{0}
861+
762862
func (s *SQLiteStmt) bind(args []namedValue) error {
763863
rv := C.sqlite3_reset(s.s)
764864
if rv != C.SQLITE_ROW && rv != C.SQLITE_OK && rv != C.SQLITE_DONE {
@@ -780,8 +880,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
780880
rv = C.sqlite3_bind_null(s.s, n)
781881
case string:
782882
if len(v) == 0 {
783-
b := []byte{0}
784-
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(0))
883+
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&placeHolder[0])), C.int(0))
785884
} else {
786885
b := []byte(v)
787886
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@@ -797,11 +896,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
797896
case float64:
798897
rv = C.sqlite3_bind_double(s.s, n, C.double(v))
799898
case []byte:
800-
if len(v) == 0 {
801-
rv = C._sqlite3_bind_blob(s.s, n, nil, 0)
802-
} else {
803-
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(len(v)))
899+
ln := len(v)
900+
if ln == 0 {
901+
v = placeHolder
804902
}
903+
rv = C._sqlite3_bind_blob(s.s, n, unsafe.Pointer(&v[0]), C.int(ln))
805904
case time.Time:
806905
b := []byte(v.Format(SQLiteTimestampFormats[0]))
807906
rv = C._sqlite3_bind_text(s.s, n, (*C.char)(unsafe.Pointer(&b[0])), C.int(len(b)))
@@ -836,6 +935,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
836935
cols: nil,
837936
decltype: nil,
838937
cls: s.cls,
938+
closed: false,
839939
done: make(chan struct{}),
840940
}
841941

@@ -908,25 +1008,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
9081008

9091009
// Close the rows.
9101010
func (rc *SQLiteRows) Close() error {
911-
if rc.s.closed {
1011+
rc.s.mu.Lock()
1012+
if rc.s.closed || rc.closed {
1013+
rc.s.mu.Unlock()
9121014
return nil
9131015
}
1016+
rc.closed = true
9141017
if rc.done != nil {
9151018
close(rc.done)
9161019
}
9171020
if rc.cls {
1021+
rc.s.mu.Unlock()
9181022
return rc.s.Close()
9191023
}
9201024
rv := C.sqlite3_reset(rc.s.s)
9211025
if rv != C.SQLITE_OK {
1026+
rc.s.mu.Unlock()
9221027
return rc.s.c.lastError()
9231028
}
1029+
rc.s.mu.Unlock()
9241030
return nil
9251031
}
9261032

9271033
// Columns return column names.
9281034
func (rc *SQLiteRows) Columns() []string {
929-
if rc.nc != len(rc.cols) {
1035+
rc.s.mu.Lock()
1036+
defer rc.s.mu.Unlock()
1037+
if rc.s.s != nil && rc.nc != len(rc.cols) {
9301038
rc.cols = make([]string, rc.nc)
9311039
for i := 0; i < rc.nc; i++ {
9321040
rc.cols[i] = C.GoString(C.sqlite3_column_name(rc.s.s, C.int(i)))
@@ -935,9 +1043,8 @@ func (rc *SQLiteRows) Columns() []string {
9351043
return rc.cols
9361044
}
9371045

938-
// DeclTypes return column types.
939-
func (rc *SQLiteRows) DeclTypes() []string {
940-
if rc.decltype == nil {
1046+
func (rc *SQLiteRows) declTypes() []string {
1047+
if rc.s.s != nil && rc.decltype == nil {
9411048
rc.decltype = make([]string, rc.nc)
9421049
for i := 0; i < rc.nc; i++ {
9431050
rc.decltype[i] = strings.ToLower(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i))))
@@ -946,8 +1053,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
9461053
return rc.decltype
9471054
}
9481055

1056+
// DeclTypes return column types.
1057+
func (rc *SQLiteRows) DeclTypes() []string {
1058+
rc.s.mu.Lock()
1059+
defer rc.s.mu.Unlock()
1060+
return rc.declTypes()
1061+
}
1062+
9491063
// Next move cursor to next.
9501064
func (rc *SQLiteRows) Next(dest []driver.Value) error {
1065+
if rc.s.closed {
1066+
return io.EOF
1067+
}
1068+
rc.s.mu.Lock()
1069+
defer rc.s.mu.Unlock()
9511070
rv := C.sqlite3_step(rc.s.s)
9521071
if rv == C.SQLITE_DONE {
9531072
return io.EOF
@@ -960,7 +1079,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
9601079
return nil
9611080
}
9621081

963-
rc.DeclTypes()
1082+
rc.declTypes()
9641083

9651084
for i := range dest {
9661085
switch C.sqlite3_column_type(rc.s.s, C.int(i)) {
@@ -973,10 +1092,11 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
9731092
// large to be a reasonable timestamp in seconds.
9741093
if val > 1e12 || val < -1e12 {
9751094
val *= int64(time.Millisecond) // convert ms to nsec
1095+
t = time.Unix(0, val)
9761096
} else {
977-
val *= int64(time.Second) // convert sec to nsec
1097+
t = time.Unix(val, 0)
9781098
}
979-
t = time.Unix(0, val).UTC()
1099+
t = t.UTC()
9801100
if rc.s.c.loc != nil {
9811101
t = t.In(rc.s.c.loc)
9821102
}

0 commit comments

Comments
 (0)