@@ -8,8 +8,11 @@ package sqlite3
8
8
9
9
import (
10
10
"database/sql"
11
+ "errors"
11
12
"fmt"
12
13
"os"
14
+ "reflect"
15
+ "strings"
13
16
"testing"
14
17
)
15
18
@@ -148,8 +151,335 @@ func TestCreateModule(t *testing.T) {
148
151
t .Fatalf ("want %v but %v" , intarray [i ], value )
149
152
}
150
153
}
154
+
151
155
_ , err = db .Exec ("DROP TABLE vtab" )
152
156
if err != nil {
153
157
t .Fatalf ("couldn't drop virtual table: %v" , err )
154
158
}
155
159
}
160
+
161
+ func TestVUpdate (t * testing.T ) {
162
+ tempFilename := TempFilename (t )
163
+ defer os .Remove (tempFilename )
164
+
165
+ // create module
166
+ updateMod := & vtabUpdateModule {t , make (map [string ]* vtabUpdateTable )}
167
+
168
+ // register module
169
+ sql .Register ("sqlite3_TestVUpdate" , & SQLiteDriver {
170
+ ConnectHook : func (conn * SQLiteConn ) error {
171
+ return conn .CreateModule ("updatetest" , updateMod )
172
+ },
173
+ })
174
+
175
+ // connect
176
+ db , err := sql .Open ("sqlite3_TestVUpdate" , tempFilename )
177
+ if err != nil {
178
+ t .Fatalf ("could not open db: %v" , err )
179
+ }
180
+
181
+ // create test table
182
+ _ , err = db .Exec (`CREATE VIRTUAL TABLE vt USING updatetest(f1 integer, f2 text, f3 text)` )
183
+ if err != nil {
184
+ t .Fatalf ("could not create updatetest vtable vt, got: %v" , err )
185
+ }
186
+
187
+ // check that table is defined properly
188
+ if len (updateMod .tables ) != 1 {
189
+ t .Fatalf ("expected exactly 1 table to exist, got: %d" , len (updateMod .tables ))
190
+ }
191
+ if _ , ok := updateMod .tables ["vt" ]; ! ok {
192
+ t .Fatalf ("expected table `vt` to exist in tables" )
193
+ }
194
+
195
+ // check nothing in updatetest
196
+ rows , err := db .Query (`select * from vt` )
197
+ if err != nil {
198
+ t .Fatalf ("could not query vt, got: %v" , err )
199
+ }
200
+ i , err := getRowCount (rows )
201
+ if err != nil {
202
+ t .Fatalf ("expected no error, got: %v" , err )
203
+ }
204
+ if i != 0 {
205
+ t .Fatalf ("expected no rows in vt, got: %d" , i )
206
+ }
207
+
208
+ _ , err = db .Exec (`delete from vt where f1 = 'yes'` )
209
+ if err != nil {
210
+ t .Fatalf ("expected error on delete, got nil" )
211
+ }
212
+
213
+ // test bad column name
214
+ _ , err = db .Exec (`insert into vt (f4) values('a')` )
215
+ if err == nil {
216
+ t .Fatalf ("expected error on insert, got nil" )
217
+ }
218
+
219
+ // insert to vt
220
+ res , err := db .Exec (`insert into vt (f1, f2, f3) values (115, 'b', 'c'), (116, 'd', 'e')` )
221
+ if err != nil {
222
+ t .Fatalf ("expected no error on insert, got: %v" , err )
223
+ }
224
+ n , err := res .RowsAffected ()
225
+ if err != nil {
226
+ t .Fatalf ("expected no error, got: %v" , err )
227
+ }
228
+ if n != 2 {
229
+ t .Fatalf ("expected 1 row affected, got: %d" , n )
230
+ }
231
+
232
+ // check vt table
233
+ vt := updateMod .tables ["vt" ]
234
+ if len (vt .data ) != 2 {
235
+ t .Fatalf ("expected table vt to have exactly 2 rows, got: %d" , len (vt .data ))
236
+ }
237
+ if ! reflect .DeepEqual (vt .data [0 ], []interface {}{int64 (115 ), "b" , "c" }) {
238
+ t .Fatalf ("expected table vt entry 0 to be [115 b c], instead: %v" , vt .data [0 ])
239
+ }
240
+ if ! reflect .DeepEqual (vt .data [1 ], []interface {}{int64 (116 ), "d" , "e" }) {
241
+ t .Fatalf ("expected table vt entry 1 to be [116 d e], instead: %v" , vt .data [1 ])
242
+ }
243
+
244
+ // query vt
245
+ var f1 int
246
+ var f2 , f3 string
247
+ err = db .QueryRow (`select * from vt where f1 = 115` ).Scan (& f1 , & f2 , & f3 )
248
+ if err != nil {
249
+ t .Fatalf ("expected no error on vt query, got: %v" , err )
250
+ }
251
+
252
+ // check column values
253
+ if f1 != 115 || f2 != "b" || f3 != "c" {
254
+ t .Errorf ("expected f1==115, f2==b, f3==c, got: %d, %q, %q" , f1 , f2 , f3 )
255
+ }
256
+
257
+ // update vt
258
+ res , err = db .Exec (`update vt set f1=117, f2='f' where f3='e'` )
259
+ if err != nil {
260
+ t .Fatalf ("expected no error, got: %v" , err )
261
+ }
262
+ n , err = res .RowsAffected ()
263
+ if err != nil {
264
+ t .Fatalf ("expected no error, got: %v" , err )
265
+ }
266
+ if n != 1 {
267
+ t .Fatalf ("expected exactly one row updated, got: %d" , n )
268
+ }
269
+
270
+ // check vt table
271
+ if len (vt .data ) != 2 {
272
+ t .Fatalf ("expected table vt to have exactly 2 rows, got: %d" , len (vt .data ))
273
+ }
274
+ if ! reflect .DeepEqual (vt .data [0 ], []interface {}{int64 (115 ), "b" , "c" }) {
275
+ t .Fatalf ("expected table vt entry 0 to be [115 b c], instead: %v" , vt .data [0 ])
276
+ }
277
+ if ! reflect .DeepEqual (vt .data [1 ], []interface {}{int64 (117 ), "f" , "e" }) {
278
+ t .Fatalf ("expected table vt entry 1 to be [117 f e], instead: %v" , vt .data [1 ])
279
+ }
280
+
281
+ // delete from vt
282
+ res , err = db .Exec (`delete from vt where f1 = 117` )
283
+ if err != nil {
284
+ t .Fatalf ("expected no error, got: %v" , err )
285
+ }
286
+ n , err = res .RowsAffected ()
287
+ if err != nil {
288
+ t .Fatalf ("expected no error, got: %v" , err )
289
+ }
290
+ if n != 1 {
291
+ t .Fatalf ("expected exactly one row deleted, got: %d" , n )
292
+ }
293
+
294
+ // check vt table
295
+ if len (vt .data ) != 1 {
296
+ t .Fatalf ("expected table vt to have exactly 1 row, got: %d" , len (vt .data ))
297
+ }
298
+ if ! reflect .DeepEqual (vt .data [0 ], []interface {}{int64 (115 ), "b" , "c" }) {
299
+ t .Fatalf ("expected table vt entry 0 to be [115 b c], instead: %v" , vt .data [0 ])
300
+ }
301
+
302
+ // check updatetest has 1 result
303
+ rows , err = db .Query (`select * from vt` )
304
+ if err != nil {
305
+ t .Fatalf ("could not query vt, got: %v" , err )
306
+ }
307
+ i , err = getRowCount (rows )
308
+ if err != nil {
309
+ t .Fatalf ("expected no error, got: %v" , err )
310
+ }
311
+ if i != 1 {
312
+ t .Fatalf ("expected 1 row in vt, got: %d" , i )
313
+ }
314
+ }
315
+
316
+ func getRowCount (rows * sql.Rows ) (int , error ) {
317
+ var i int
318
+ for rows .Next () {
319
+ i ++
320
+ }
321
+ return i , nil
322
+ }
323
+
324
+ type vtabUpdateModule struct {
325
+ t * testing.T
326
+ tables map [string ]* vtabUpdateTable
327
+ }
328
+
329
+ func (m * vtabUpdateModule ) Create (c * SQLiteConn , args []string ) (VTab , error ) {
330
+ if len (args ) < 2 {
331
+ return nil , errors .New ("must declare at least one column" )
332
+ }
333
+
334
+ // get database name, table name, and column declarations ...
335
+ dbname , tname , decls := args [1 ], args [2 ], args [3 :]
336
+
337
+ // extract column names + types from parameters declarations
338
+ cols , typs := make ([]string , len (decls )), make ([]string , len (decls ))
339
+ for i := 0 ; i < len (decls ); i ++ {
340
+ n , typ := decls [i ], ""
341
+ if j := strings .IndexAny (n , " \t \n " ); j != - 1 {
342
+ typ , n = strings .TrimSpace (n [j + 1 :]), n [:j ]
343
+ }
344
+ cols [i ], typs [i ] = n , typ
345
+ }
346
+
347
+ // declare table
348
+ err := c .DeclareVTab (fmt .Sprintf (`CREATE TABLE "%s"."%s" (%s)` , dbname , tname , strings .Join (decls , "," )))
349
+ if err != nil {
350
+ return nil , err
351
+ }
352
+
353
+ // create table
354
+ vtab := & vtabUpdateTable {m .t , dbname , tname , cols , typs , make ([][]interface {}, 0 )}
355
+ m .tables [tname ] = vtab
356
+ return vtab , nil
357
+ }
358
+
359
+ func (m * vtabUpdateModule ) Connect (c * SQLiteConn , args []string ) (VTab , error ) {
360
+ return m .Create (c , args )
361
+ }
362
+
363
+ func (m * vtabUpdateModule ) DestroyModule () {}
364
+
365
+ type vtabUpdateTable struct {
366
+ t * testing.T
367
+ db string
368
+ name string
369
+ cols []string
370
+ typs []string
371
+ data [][]interface {}
372
+ }
373
+
374
+ func (t * vtabUpdateTable ) Open () (VTabCursor , error ) {
375
+ return & vtabUpdateCursor {t , 0 }, nil
376
+ }
377
+
378
+ func (t * vtabUpdateTable ) BestIndex (cst []InfoConstraint , ob []InfoOrderBy ) (* IndexResult , error ) {
379
+ return & IndexResult {Used : make ([]bool , len (cst ))}, nil
380
+ }
381
+
382
+ func (t * vtabUpdateTable ) Disconnect () error {
383
+ return nil
384
+ }
385
+
386
+ func (t * vtabUpdateTable ) Destroy () error {
387
+ return nil
388
+ }
389
+
390
+ func (t * vtabUpdateTable ) Insert (id interface {}, vals []interface {}) (int64 , error ) {
391
+ var i int64
392
+ if id == nil {
393
+ i , t .data = int64 (len (t .data )), append (t .data , vals )
394
+ return i , nil
395
+ }
396
+
397
+ var ok bool
398
+ i , ok = id .(int64 )
399
+ if ! ok {
400
+ return 0 , fmt .Errorf ("id is invalid type: %T" , id )
401
+ }
402
+
403
+ t .data [i ] = vals
404
+
405
+ return i , nil
406
+ }
407
+
408
+ func (t * vtabUpdateTable ) Update (id interface {}, vals []interface {}) error {
409
+ i , ok := id .(int64 )
410
+ if ! ok {
411
+ return fmt .Errorf ("id is invalid type: %T" , id )
412
+ }
413
+
414
+ if int (i ) >= len (t .data ) || i < 0 {
415
+ return fmt .Errorf ("invalid row id %d" , i )
416
+ }
417
+
418
+ t .data [int (i )] = vals
419
+
420
+ return nil
421
+ }
422
+
423
+ func (t * vtabUpdateTable ) Delete (id interface {}) error {
424
+ i , ok := id .(int64 )
425
+ if ! ok {
426
+ return fmt .Errorf ("id is invalid type: %T" , id )
427
+ }
428
+
429
+ if int (i ) >= len (t .data ) || i < 0 {
430
+ return fmt .Errorf ("invalid row id %d" , i )
431
+ }
432
+
433
+ t .data = append (t .data [:i ], t .data [i + 1 :]... )
434
+
435
+ return nil
436
+ }
437
+
438
+ type vtabUpdateCursor struct {
439
+ t * vtabUpdateTable
440
+ i int
441
+ }
442
+
443
+ func (c * vtabUpdateCursor ) Column (ctxt * SQLiteContext , col int ) error {
444
+ switch x := c .t .data [c .i ][col ].(type ) {
445
+ case []byte :
446
+ ctxt .ResultBlob (x )
447
+ case bool :
448
+ ctxt .ResultBool (x )
449
+ case float64 :
450
+ ctxt .ResultDouble (x )
451
+ case int :
452
+ ctxt .ResultInt (x )
453
+ case int64 :
454
+ ctxt .ResultInt64 (x )
455
+ case nil :
456
+ ctxt .ResultNull ()
457
+ case string :
458
+ ctxt .ResultText (x )
459
+ default :
460
+ ctxt .ResultText (fmt .Sprintf ("%v" , x ))
461
+ }
462
+
463
+ return nil
464
+ }
465
+
466
+ func (c * vtabUpdateCursor ) Filter (ixNum int , ixName string , vals []interface {}) error {
467
+ return nil
468
+ }
469
+
470
+ func (c * vtabUpdateCursor ) Next () error {
471
+ c .i ++
472
+ return nil
473
+ }
474
+
475
+ func (c * vtabUpdateCursor ) EOF () bool {
476
+ return c .i >= len (c .t .data )
477
+ }
478
+
479
+ func (c * vtabUpdateCursor ) Rowid () (int64 , error ) {
480
+ return int64 (c .i ), nil
481
+ }
482
+
483
+ func (c * vtabUpdateCursor ) Close () error {
484
+ return nil
485
+ }
0 commit comments