@@ -7,7 +7,7 @@ package sqlite3
7
7
8
8
/*
9
9
#cgo CFLAGS: -std=gnu99
10
- #cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE
10
+ #cgo CFLAGS: -DSQLITE_ENABLE_RTREE -DSQLITE_THREADSAFE=1
11
11
#cgo CFLAGS: -DSQLITE_ENABLE_FTS3 -DSQLITE_ENABLE_FTS3_PARENTHESIS -DSQLITE_ENABLE_FTS4_UNICODE61
12
12
#cgo CFLAGS: -DSQLITE_TRACE_SIZE_LIMIT=15
13
13
#cgo CFLAGS: -DSQLITE_DISABLE_INTRINSIC
@@ -102,6 +102,9 @@ int _sqlite3_create_function(
102
102
void callbackTrampoline(sqlite3_context*, int, sqlite3_value**);
103
103
104
104
int compareTrampoline(void*, int, char*, int, char*);
105
+ int commitHookTrampoline(void*);
106
+ void rollbackHookTrampoline(void*);
107
+ void updateHookTrampoline(void*, int, char*, char*, sqlite3_int64);
105
108
*/
106
109
import "C"
107
110
import (
@@ -115,6 +118,7 @@ import (
115
118
"runtime"
116
119
"strconv"
117
120
"strings"
121
+ "sync"
118
122
"time"
119
123
"unsafe"
120
124
@@ -151,6 +155,12 @@ func Version() (libVersion string, libVersionNumber int, sourceID string) {
151
155
return libVersion , libVersionNumber , sourceID
152
156
}
153
157
158
+ const (
159
+ SQLITE_DELETE = C .SQLITE_DELETE
160
+ SQLITE_INSERT = C .SQLITE_INSERT
161
+ SQLITE_UPDATE = C .SQLITE_UPDATE
162
+ )
163
+
154
164
// SQLiteDriver implement sql.Driver.
155
165
type SQLiteDriver struct {
156
166
Extensions []string
@@ -159,6 +169,7 @@ type SQLiteDriver struct {
159
169
160
170
// SQLiteConn implement sql.Conn.
161
171
type SQLiteConn struct {
172
+ mu sync.Mutex
162
173
db * C.sqlite3
163
174
loc * time.Location
164
175
txlock string
@@ -173,6 +184,7 @@ type SQLiteTx struct {
173
184
174
185
// SQLiteStmt implement sql.Stmt.
175
186
type SQLiteStmt struct {
187
+ mu sync.Mutex
176
188
c * SQLiteConn
177
189
s * C.sqlite3_stmt
178
190
t string
@@ -193,6 +205,7 @@ type SQLiteRows struct {
193
205
cols []string
194
206
decltype []string
195
207
cls bool
208
+ closed bool
196
209
done chan struct {}
197
210
}
198
211
@@ -338,6 +351,51 @@ func (c *SQLiteConn) RegisterCollation(name string, cmp func(string, string) int
338
351
return nil
339
352
}
340
353
354
+ // RegisterCommitHook sets the commit hook for a connection.
355
+ //
356
+ // If the callback returns non-zero the transaction will become a rollback.
357
+ //
358
+ // If there is an existing commit hook for this connection, it will be
359
+ // removed. If callback is nil the existing hook (if any) will be removed
360
+ // without creating a new one.
361
+ func (c * SQLiteConn ) RegisterCommitHook (callback func () int ) {
362
+ if callback == nil {
363
+ C .sqlite3_commit_hook (c .db , nil , nil )
364
+ } else {
365
+ C .sqlite3_commit_hook (c .db , (* [0 ]byte )(unsafe .Pointer (C .commitHookTrampoline )), unsafe .Pointer (newHandle (c , callback )))
366
+ }
367
+ }
368
+
369
+ // RegisterRollbackHook sets the rollback hook for a connection.
370
+ //
371
+ // If there is an existing rollback hook for this connection, it will be
372
+ // removed. If callback is nil the existing hook (if any) will be removed
373
+ // without creating a new one.
374
+ func (c * SQLiteConn ) RegisterRollbackHook (callback func ()) {
375
+ if callback == nil {
376
+ C .sqlite3_rollback_hook (c .db , nil , nil )
377
+ } else {
378
+ C .sqlite3_rollback_hook (c .db , (* [0 ]byte )(unsafe .Pointer (C .rollbackHookTrampoline )), unsafe .Pointer (newHandle (c , callback )))
379
+ }
380
+ }
381
+
382
+ // RegisterUpdateHook sets the update hook for a connection.
383
+ //
384
+ // The parameters to the callback are the operation (one of the constants
385
+ // SQLITE_INSERT, SQLITE_DELETE, or SQLITE_UPDATE), the database name, the
386
+ // table name, and the rowid.
387
+ //
388
+ // If there is an existing update hook for this connection, it will be
389
+ // removed. If callback is nil the existing hook (if any) will be removed
390
+ // without creating a new one.
391
+ func (c * SQLiteConn ) RegisterUpdateHook (callback func (int , string , string , int64 )) {
392
+ if callback == nil {
393
+ C .sqlite3_update_hook (c .db , nil , nil )
394
+ } else {
395
+ C .sqlite3_update_hook (c .db , (* [0 ]byte )(unsafe .Pointer (C .updateHookTrampoline )), unsafe .Pointer (newHandle (c , callback )))
396
+ }
397
+ }
398
+
341
399
// RegisterFunc makes a Go function available as a SQLite function.
342
400
//
343
401
// The Go function can have arguments of the following types: any
@@ -568,6 +626,8 @@ func errorString(err Error) string {
568
626
// "deferred", "exclusive".
569
627
// _foreign_keys=X
570
628
// Enable or disable enforcement of foreign keys. X can be 1 or 0.
629
+ // _recursive_triggers=X
630
+ // Enable or disable recursive triggers. X can be 1 or 0.
571
631
func (d * SQLiteDriver ) Open (dsn string ) (driver.Conn , error ) {
572
632
if C .sqlite3_threadsafe () == 0 {
573
633
return nil , errors .New ("sqlite library was not compiled for thread-safe operation" )
@@ -577,6 +637,7 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
577
637
txlock := "BEGIN"
578
638
busyTimeout := 5000
579
639
foreignKeys := - 1
640
+ recursiveTriggers := - 1
580
641
pos := strings .IndexRune (dsn , '?' )
581
642
if pos >= 1 {
582
643
params , err := url .ParseQuery (dsn [pos + 1 :])
@@ -631,6 +692,18 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
631
692
}
632
693
}
633
694
695
+ // _recursive_triggers
696
+ if val := params .Get ("_recursive_triggers" ); val != "" {
697
+ switch val {
698
+ case "1" :
699
+ recursiveTriggers = 1
700
+ case "0" :
701
+ recursiveTriggers = 0
702
+ default :
703
+ return nil , fmt .Errorf ("Invalid _recursive_triggers: %v" , val )
704
+ }
705
+ }
706
+
634
707
if ! strings .HasPrefix (dsn , "file:" ) {
635
708
dsn = dsn [:pos ]
636
709
}
@@ -677,6 +750,17 @@ func (d *SQLiteDriver) Open(dsn string) (driver.Conn, error) {
677
750
return nil , err
678
751
}
679
752
}
753
+ if recursiveTriggers == 0 {
754
+ if err := exec ("PRAGMA recursive_triggers = OFF;" ); err != nil {
755
+ C .sqlite3_close_v2 (db )
756
+ return nil , err
757
+ }
758
+ } else if recursiveTriggers == 1 {
759
+ if err := exec ("PRAGMA recursive_triggers = ON;" ); err != nil {
760
+ C .sqlite3_close_v2 (db )
761
+ return nil , err
762
+ }
763
+ }
680
764
681
765
conn := & SQLiteConn {db : db , loc : loc , txlock : txlock }
682
766
@@ -704,11 +788,22 @@ func (c *SQLiteConn) Close() error {
704
788
return c .lastError ()
705
789
}
706
790
deleteHandles (c )
791
+ c .mu .Lock ()
707
792
c .db = nil
793
+ c .mu .Unlock ()
708
794
runtime .SetFinalizer (c , nil )
709
795
return nil
710
796
}
711
797
798
+ func (c * SQLiteConn ) dbConnOpen () bool {
799
+ if c == nil {
800
+ return false
801
+ }
802
+ c .mu .Lock ()
803
+ defer c .mu .Unlock ()
804
+ return c .db != nil
805
+ }
806
+
712
807
// Prepare the query string. Return a new statement.
713
808
func (c * SQLiteConn ) Prepare (query string ) (driver.Stmt , error ) {
714
809
return c .prepare (context .Background (), query )
@@ -734,14 +829,17 @@ func (c *SQLiteConn) prepare(ctx context.Context, query string) (driver.Stmt, er
734
829
735
830
// Close the statement.
736
831
func (s * SQLiteStmt ) Close () error {
832
+ s .mu .Lock ()
833
+ defer s .mu .Unlock ()
737
834
if s .closed {
738
835
return nil
739
836
}
740
837
s .closed = true
741
- if s .c == nil || s . c . db == nil {
838
+ if ! s .c . dbConnOpen () {
742
839
return errors .New ("sqlite statement with already closed database connection" )
743
840
}
744
841
rv := C .sqlite3_finalize (s .s )
842
+ s .s = nil
745
843
if rv != C .SQLITE_OK {
746
844
return s .c .lastError ()
747
845
}
@@ -759,6 +857,8 @@ type bindArg struct {
759
857
v driver.Value
760
858
}
761
859
860
+ var placeHolder = []byte {0 }
861
+
762
862
func (s * SQLiteStmt ) bind (args []namedValue ) error {
763
863
rv := C .sqlite3_reset (s .s )
764
864
if rv != C .SQLITE_ROW && rv != C .SQLITE_OK && rv != C .SQLITE_DONE {
@@ -780,8 +880,7 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
780
880
rv = C .sqlite3_bind_null (s .s , n )
781
881
case string :
782
882
if len (v ) == 0 {
783
- b := []byte {0 }
784
- rv = C ._sqlite3_bind_text (s .s , n , (* C .char )(unsafe .Pointer (& b [0 ])), C .int (0 ))
883
+ rv = C ._sqlite3_bind_text (s .s , n , (* C .char )(unsafe .Pointer (& placeHolder [0 ])), C .int (0 ))
785
884
} else {
786
885
b := []byte (v )
787
886
rv = C ._sqlite3_bind_text (s .s , n , (* C .char )(unsafe .Pointer (& b [0 ])), C .int (len (b )))
@@ -797,11 +896,11 @@ func (s *SQLiteStmt) bind(args []namedValue) error {
797
896
case float64 :
798
897
rv = C .sqlite3_bind_double (s .s , n , C .double (v ))
799
898
case []byte :
800
- if len (v ) == 0 {
801
- rv = C ._sqlite3_bind_blob (s .s , n , nil , 0 )
802
- } else {
803
- rv = C ._sqlite3_bind_blob (s .s , n , unsafe .Pointer (& v [0 ]), C .int (len (v )))
899
+ ln := len (v )
900
+ if ln == 0 {
901
+ v = placeHolder
804
902
}
903
+ rv = C ._sqlite3_bind_blob (s .s , n , unsafe .Pointer (& v [0 ]), C .int (ln ))
805
904
case time.Time :
806
905
b := []byte (v .Format (SQLiteTimestampFormats [0 ]))
807
906
rv = C ._sqlite3_bind_text (s .s , n , (* C .char )(unsafe .Pointer (& b [0 ])), C .int (len (b )))
@@ -836,6 +935,7 @@ func (s *SQLiteStmt) query(ctx context.Context, args []namedValue) (driver.Rows,
836
935
cols : nil ,
837
936
decltype : nil ,
838
937
cls : s .cls ,
938
+ closed : false ,
839
939
done : make (chan struct {}),
840
940
}
841
941
@@ -908,25 +1008,33 @@ func (s *SQLiteStmt) exec(ctx context.Context, args []namedValue) (driver.Result
908
1008
909
1009
// Close the rows.
910
1010
func (rc * SQLiteRows ) Close () error {
911
- if rc .s .closed {
1011
+ rc .s .mu .Lock ()
1012
+ if rc .s .closed || rc .closed {
1013
+ rc .s .mu .Unlock ()
912
1014
return nil
913
1015
}
1016
+ rc .closed = true
914
1017
if rc .done != nil {
915
1018
close (rc .done )
916
1019
}
917
1020
if rc .cls {
1021
+ rc .s .mu .Unlock ()
918
1022
return rc .s .Close ()
919
1023
}
920
1024
rv := C .sqlite3_reset (rc .s .s )
921
1025
if rv != C .SQLITE_OK {
1026
+ rc .s .mu .Unlock ()
922
1027
return rc .s .c .lastError ()
923
1028
}
1029
+ rc .s .mu .Unlock ()
924
1030
return nil
925
1031
}
926
1032
927
1033
// Columns return column names.
928
1034
func (rc * SQLiteRows ) Columns () []string {
929
- if rc .nc != len (rc .cols ) {
1035
+ rc .s .mu .Lock ()
1036
+ defer rc .s .mu .Unlock ()
1037
+ if rc .s .s != nil && rc .nc != len (rc .cols ) {
930
1038
rc .cols = make ([]string , rc .nc )
931
1039
for i := 0 ; i < rc .nc ; i ++ {
932
1040
rc .cols [i ] = C .GoString (C .sqlite3_column_name (rc .s .s , C .int (i )))
@@ -935,9 +1043,8 @@ func (rc *SQLiteRows) Columns() []string {
935
1043
return rc .cols
936
1044
}
937
1045
938
- // DeclTypes return column types.
939
- func (rc * SQLiteRows ) DeclTypes () []string {
940
- if rc .decltype == nil {
1046
+ func (rc * SQLiteRows ) declTypes () []string {
1047
+ if rc .s .s != nil && rc .decltype == nil {
941
1048
rc .decltype = make ([]string , rc .nc )
942
1049
for i := 0 ; i < rc .nc ; i ++ {
943
1050
rc .decltype [i ] = strings .ToLower (C .GoString (C .sqlite3_column_decltype (rc .s .s , C .int (i ))))
@@ -946,8 +1053,20 @@ func (rc *SQLiteRows) DeclTypes() []string {
946
1053
return rc .decltype
947
1054
}
948
1055
1056
+ // DeclTypes return column types.
1057
+ func (rc * SQLiteRows ) DeclTypes () []string {
1058
+ rc .s .mu .Lock ()
1059
+ defer rc .s .mu .Unlock ()
1060
+ return rc .declTypes ()
1061
+ }
1062
+
949
1063
// Next move cursor to next.
950
1064
func (rc * SQLiteRows ) Next (dest []driver.Value ) error {
1065
+ if rc .s .closed {
1066
+ return io .EOF
1067
+ }
1068
+ rc .s .mu .Lock ()
1069
+ defer rc .s .mu .Unlock ()
951
1070
rv := C .sqlite3_step (rc .s .s )
952
1071
if rv == C .SQLITE_DONE {
953
1072
return io .EOF
@@ -960,7 +1079,7 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
960
1079
return nil
961
1080
}
962
1081
963
- rc .DeclTypes ()
1082
+ rc .declTypes ()
964
1083
965
1084
for i := range dest {
966
1085
switch C .sqlite3_column_type (rc .s .s , C .int (i )) {
@@ -973,10 +1092,11 @@ func (rc *SQLiteRows) Next(dest []driver.Value) error {
973
1092
// large to be a reasonable timestamp in seconds.
974
1093
if val > 1e12 || val < - 1e12 {
975
1094
val *= int64 (time .Millisecond ) // convert ms to nsec
1095
+ t = time .Unix (0 , val )
976
1096
} else {
977
- val *= int64 ( time .Second ) // convert sec to nsec
1097
+ t = time .Unix ( val , 0 )
978
1098
}
979
- t = time . Unix ( 0 , val ) .UTC ()
1099
+ t = t .UTC ()
980
1100
if rc .s .c .loc != nil {
981
1101
t = t .In (rc .s .c .loc )
982
1102
}
0 commit comments