Skip to content

Commit 18135fa

Browse files
committed
Adding unit test for VTable Insert/Update/Delete
1 parent f68bb95 commit 18135fa

File tree

2 files changed

+338
-1
lines changed

2 files changed

+338
-1
lines changed

sqlite3_vtable.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,14 @@ func goVUpdate(pVTab unsafe.Pointer, argc C.int, argv **C.sqlite3_value, pRowid
537537
if err != nil {
538538
return mPrintf("%s", err.Error())
539539
}
540-
vals = append(vals, conv.Interface())
540+
541+
// work around for SQLITE_NULL
542+
x := conv.Interface()
543+
if z, ok := x.([]byte); ok && z == nil {
544+
x = nil
545+
}
546+
547+
vals = append(vals, x)
541548
}
542549

543550
switch {

sqlite3_vtable_test.go

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@ package sqlite3
88

99
import (
1010
"database/sql"
11+
"errors"
1112
"fmt"
1213
"os"
14+
"reflect"
15+
"strings"
1316
"testing"
1417
)
1518

@@ -148,8 +151,335 @@ func TestCreateModule(t *testing.T) {
148151
t.Fatalf("want %v but %v", intarray[i], value)
149152
}
150153
}
154+
151155
_, err = db.Exec("DROP TABLE vtab")
152156
if err != nil {
153157
t.Fatalf("couldn't drop virtual table: %v", err)
154158
}
155159
}
160+
161+
func TestVUpdate(t *testing.T) {
162+
tempFilename := TempFilename(t)
163+
defer os.Remove(tempFilename)
164+
165+
// create module
166+
updateMod := &vtabUpdateModule{t, make(map[string]*vtabUpdateTable)}
167+
168+
// register module
169+
sql.Register("sqlite3_TestVUpdate", &SQLiteDriver{
170+
ConnectHook: func(conn *SQLiteConn) error {
171+
return conn.CreateModule("updatetest", updateMod)
172+
},
173+
})
174+
175+
// connect
176+
db, err := sql.Open("sqlite3_TestVUpdate", tempFilename)
177+
if err != nil {
178+
t.Fatalf("could not open db: %v", err)
179+
}
180+
181+
// create test table
182+
_, err = db.Exec(`CREATE VIRTUAL TABLE vt USING updatetest(f1 integer, f2 text, f3 text)`)
183+
if err != nil {
184+
t.Fatalf("could not create updatetest vtable vt, got: %v", err)
185+
}
186+
187+
// check that table is defined properly
188+
if len(updateMod.tables) != 1 {
189+
t.Fatalf("expected exactly 1 table to exist, got: %d", len(updateMod.tables))
190+
}
191+
if _, ok := updateMod.tables["vt"]; !ok {
192+
t.Fatalf("expected table `vt` to exist in tables")
193+
}
194+
195+
// check nothing in updatetest
196+
rows, err := db.Query(`select * from vt`)
197+
if err != nil {
198+
t.Fatalf("could not query vt, got: %v", err)
199+
}
200+
i, err := getRowCount(rows)
201+
if err != nil {
202+
t.Fatalf("expected no error, got: %v", err)
203+
}
204+
if i != 0 {
205+
t.Fatalf("expected no rows in vt, got: %d", i)
206+
}
207+
208+
_, err = db.Exec(`delete from vt where f1 = 'yes'`)
209+
if err != nil {
210+
t.Fatalf("expected error on delete, got nil")
211+
}
212+
213+
// test bad column name
214+
_, err = db.Exec(`insert into vt (f4) values('a')`)
215+
if err == nil {
216+
t.Fatalf("expected error on insert, got nil")
217+
}
218+
219+
// insert to vt
220+
res, err := db.Exec(`insert into vt (f1, f2, f3) values (115, 'b', 'c'), (116, 'd', 'e')`)
221+
if err != nil {
222+
t.Fatalf("expected no error on insert, got: %v", err)
223+
}
224+
n, err := res.RowsAffected()
225+
if err != nil {
226+
t.Fatalf("expected no error, got: %v", err)
227+
}
228+
if n != 2 {
229+
t.Fatalf("expected 1 row affected, got: %d", n)
230+
}
231+
232+
// check vt table
233+
vt := updateMod.tables["vt"]
234+
if len(vt.data) != 2 {
235+
t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
236+
}
237+
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
238+
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
239+
}
240+
if !reflect.DeepEqual(vt.data[1], []interface{}{int64(116), "d", "e"}) {
241+
t.Fatalf("expected table vt entry 1 to be [116 d e], instead: %v", vt.data[1])
242+
}
243+
244+
// query vt
245+
var f1 int
246+
var f2, f3 string
247+
err = db.QueryRow(`select * from vt where f1 = 115`).Scan(&f1, &f2, &f3)
248+
if err != nil {
249+
t.Fatalf("expected no error on vt query, got: %v", err)
250+
}
251+
252+
// check column values
253+
if f1 != 115 || f2 != "b" || f3 != "c" {
254+
t.Errorf("expected f1==115, f2==b, f3==c, got: %d, %q, %q", f1, f2, f3)
255+
}
256+
257+
// update vt
258+
res, err = db.Exec(`update vt set f1=117, f2='f' where f3='e'`)
259+
if err != nil {
260+
t.Fatalf("expected no error, got: %v", err)
261+
}
262+
n, err = res.RowsAffected()
263+
if err != nil {
264+
t.Fatalf("expected no error, got: %v", err)
265+
}
266+
if n != 1 {
267+
t.Fatalf("expected exactly one row updated, got: %d", n)
268+
}
269+
270+
// check vt table
271+
if len(vt.data) != 2 {
272+
t.Fatalf("expected table vt to have exactly 2 rows, got: %d", len(vt.data))
273+
}
274+
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
275+
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
276+
}
277+
if !reflect.DeepEqual(vt.data[1], []interface{}{int64(117), "f", "e"}) {
278+
t.Fatalf("expected table vt entry 1 to be [117 f e], instead: %v", vt.data[1])
279+
}
280+
281+
// delete from vt
282+
res, err = db.Exec(`delete from vt where f1 = 117`)
283+
if err != nil {
284+
t.Fatalf("expected no error, got: %v", err)
285+
}
286+
n, err = res.RowsAffected()
287+
if err != nil {
288+
t.Fatalf("expected no error, got: %v", err)
289+
}
290+
if n != 1 {
291+
t.Fatalf("expected exactly one row deleted, got: %d", n)
292+
}
293+
294+
// check vt table
295+
if len(vt.data) != 1 {
296+
t.Fatalf("expected table vt to have exactly 1 row, got: %d", len(vt.data))
297+
}
298+
if !reflect.DeepEqual(vt.data[0], []interface{}{int64(115), "b", "c"}) {
299+
t.Fatalf("expected table vt entry 0 to be [115 b c], instead: %v", vt.data[0])
300+
}
301+
302+
// check updatetest has 1 result
303+
rows, err = db.Query(`select * from vt`)
304+
if err != nil {
305+
t.Fatalf("could not query vt, got: %v", err)
306+
}
307+
i, err = getRowCount(rows)
308+
if err != nil {
309+
t.Fatalf("expected no error, got: %v", err)
310+
}
311+
if i != 1 {
312+
t.Fatalf("expected 1 row in vt, got: %d", i)
313+
}
314+
}
315+
316+
func getRowCount(rows *sql.Rows) (int, error) {
317+
var i int
318+
for rows.Next() {
319+
i++
320+
}
321+
return i, nil
322+
}
323+
324+
type vtabUpdateModule struct {
325+
t *testing.T
326+
tables map[string]*vtabUpdateTable
327+
}
328+
329+
func (m *vtabUpdateModule) Create(c *SQLiteConn, args []string) (VTab, error) {
330+
if len(args) < 2 {
331+
return nil, errors.New("must declare at least one column")
332+
}
333+
334+
// get database name, table name, and column declarations ...
335+
dbname, tname, decls := args[1], args[2], args[3:]
336+
337+
// extract column names + types from parameters declarations
338+
cols, typs := make([]string, len(decls)), make([]string, len(decls))
339+
for i := 0; i < len(decls); i++ {
340+
n, typ := decls[i], ""
341+
if j := strings.IndexAny(n, " \t\n"); j != -1 {
342+
typ, n = strings.TrimSpace(n[j+1:]), n[:j]
343+
}
344+
cols[i], typs[i] = n, typ
345+
}
346+
347+
// declare table
348+
err := c.DeclareVTab(fmt.Sprintf(`CREATE TABLE "%s"."%s" (%s)`, dbname, tname, strings.Join(decls, ",")))
349+
if err != nil {
350+
return nil, err
351+
}
352+
353+
// create table
354+
vtab := &vtabUpdateTable{m.t, dbname, tname, cols, typs, make([][]interface{}, 0)}
355+
m.tables[tname] = vtab
356+
return vtab, nil
357+
}
358+
359+
func (m *vtabUpdateModule) Connect(c *SQLiteConn, args []string) (VTab, error) {
360+
return m.Create(c, args)
361+
}
362+
363+
func (m *vtabUpdateModule) DestroyModule() {}
364+
365+
type vtabUpdateTable struct {
366+
t *testing.T
367+
db string
368+
name string
369+
cols []string
370+
typs []string
371+
data [][]interface{}
372+
}
373+
374+
func (t *vtabUpdateTable) Open() (VTabCursor, error) {
375+
return &vtabUpdateCursor{t, 0}, nil
376+
}
377+
378+
func (t *vtabUpdateTable) BestIndex(cst []InfoConstraint, ob []InfoOrderBy) (*IndexResult, error) {
379+
return &IndexResult{Used: make([]bool, len(cst))}, nil
380+
}
381+
382+
func (t *vtabUpdateTable) Disconnect() error {
383+
return nil
384+
}
385+
386+
func (t *vtabUpdateTable) Destroy() error {
387+
return nil
388+
}
389+
390+
func (t *vtabUpdateTable) Insert(id interface{}, vals []interface{}) (int64, error) {
391+
var i int64
392+
if id == nil {
393+
i, t.data = int64(len(t.data)), append(t.data, vals)
394+
return i, nil
395+
}
396+
397+
var ok bool
398+
i, ok = id.(int64)
399+
if !ok {
400+
return 0, fmt.Errorf("id is invalid type: %T", id)
401+
}
402+
403+
t.data[i] = vals
404+
405+
return i, nil
406+
}
407+
408+
func (t *vtabUpdateTable) Update(id interface{}, vals []interface{}) error {
409+
i, ok := id.(int64)
410+
if !ok {
411+
return fmt.Errorf("id is invalid type: %T", id)
412+
}
413+
414+
if int(i) >= len(t.data) || i < 0 {
415+
return fmt.Errorf("invalid row id %d", i)
416+
}
417+
418+
t.data[int(i)] = vals
419+
420+
return nil
421+
}
422+
423+
func (t *vtabUpdateTable) Delete(id interface{}) error {
424+
i, ok := id.(int64)
425+
if !ok {
426+
return fmt.Errorf("id is invalid type: %T", id)
427+
}
428+
429+
if int(i) >= len(t.data) || i < 0 {
430+
return fmt.Errorf("invalid row id %d", i)
431+
}
432+
433+
t.data = append(t.data[:i], t.data[i+1:]...)
434+
435+
return nil
436+
}
437+
438+
type vtabUpdateCursor struct {
439+
t *vtabUpdateTable
440+
i int
441+
}
442+
443+
func (c *vtabUpdateCursor) Column(ctxt *SQLiteContext, col int) error {
444+
switch x := c.t.data[c.i][col].(type) {
445+
case []byte:
446+
ctxt.ResultBlob(x)
447+
case bool:
448+
ctxt.ResultBool(x)
449+
case float64:
450+
ctxt.ResultDouble(x)
451+
case int:
452+
ctxt.ResultInt(x)
453+
case int64:
454+
ctxt.ResultInt64(x)
455+
case nil:
456+
ctxt.ResultNull()
457+
case string:
458+
ctxt.ResultText(x)
459+
default:
460+
ctxt.ResultText(fmt.Sprintf("%v", x))
461+
}
462+
463+
return nil
464+
}
465+
466+
func (c *vtabUpdateCursor) Filter(ixNum int, ixName string, vals []interface{}) error {
467+
return nil
468+
}
469+
470+
func (c *vtabUpdateCursor) Next() error {
471+
c.i++
472+
return nil
473+
}
474+
475+
func (c *vtabUpdateCursor) EOF() bool {
476+
return c.i >= len(c.t.data)
477+
}
478+
479+
func (c *vtabUpdateCursor) Rowid() (int64, error) {
480+
return int64(c.i), nil
481+
}
482+
483+
func (c *vtabUpdateCursor) Close() error {
484+
return nil
485+
}

0 commit comments

Comments
 (0)