Skip to content

Commit 9aae3f2

Browse files
committed
impl: support session in golang sql interface
1 parent f46bcc5 commit 9aae3f2

File tree

3 files changed

+193
-42
lines changed

3 files changed

+193
-42
lines changed

chdb/driver/driver.go

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package chdb
1+
package chdbdriver
22

33
import (
44
"bytes"
@@ -7,53 +7,113 @@ import (
77
"database/sql/driver"
88
"fmt"
99
"reflect"
10+
"strings"
1011
"time"
1112

1213
"github.com/apache/arrow/go/v14/arrow"
1314
"github.com/apache/arrow/go/v14/arrow/array"
1415
"github.com/apache/arrow/go/v14/arrow/decimal128"
1516
"github.com/apache/arrow/go/v14/arrow/decimal256"
16-
wrapper "github.com/chdb-io/chdb-go/chdb"
17+
"github.com/chdb-io/chdb-go/chdb"
1718
"github.com/chdb-io/chdb-go/chdbstable"
1819

1920
"github.com/apache/arrow/go/v14/arrow/ipc"
2021
)
2122

23+
const sessionOptionKey = "session"
24+
const udfPathOptionKey = "udfPath"
25+
2226
func init() {
2327
sql.Register("chdb", Driver{})
2428
}
2529

30+
type queryHandle func(string, ...string) *chdbstable.LocalResult
31+
2632
type connector struct {
33+
udfPath string
34+
session *chdb.Session
2735
}
2836

2937
// Connect returns a connection to a database.
3038
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
31-
return &conn{}, nil
39+
cc := &conn{udfPath: c.udfPath, session: c.session}
40+
cc.SetupQueryFun()
41+
return cc, nil
3242
}
3343

3444
// Driver returns the underying Driver of the connector,
3545
// compatibility with the Driver method on sql.DB
3646
func (c *connector) Driver() driver.Driver { return Driver{} }
3747

48+
func parseConnectStr(str string) (ret map[string]string, err error) {
49+
ret = make(map[string]string)
50+
if len(str) == 0 {
51+
return
52+
}
53+
for _, kv := range strings.Split(str, ";") {
54+
parsed := strings.SplitN(kv, "=", 2)
55+
if len(parsed) != 2 {
56+
return nil, fmt.Errorf("invalid format for connection string, str: %s", kv)
57+
}
58+
59+
ret[strings.TrimSpace(parsed[0])] = strings.TrimSpace(parsed[1])
60+
}
61+
62+
return
63+
}
64+
func NewConnect(opts map[string]string) (ret *connector, err error) {
65+
ret = &connector{}
66+
sessionPath, ok := opts[sessionOptionKey]
67+
if ok {
68+
ret.session, err = chdb.NewSession(sessionPath)
69+
if err != nil {
70+
return nil, err
71+
}
72+
}
73+
udfPath, ok := opts[udfPathOptionKey]
74+
if ok {
75+
ret.udfPath = udfPath
76+
}
77+
return
78+
}
79+
3880
type Driver struct{}
3981

4082
// Open returns a new connection to the database.
4183
func (d Driver) Open(name string) (driver.Conn, error) {
42-
return &conn{}, nil
84+
cc, err := d.OpenConnector(name)
85+
if err != nil {
86+
return nil, err
87+
}
88+
return cc.Connect(context.Background())
4389
}
4490

4591
// OpenConnector expects the same format as driver.Open
46-
func (d Driver) OpenConnector(dataSourceName string) (driver.Connector, error) {
47-
return &connector{}, nil
92+
func (d Driver) OpenConnector(name string) (driver.Connector, error) {
93+
opts, err := parseConnectStr(name)
94+
if err != nil {
95+
return nil, err
96+
}
97+
return NewConnect(opts)
4898
}
4999

50100
type conn struct {
101+
udfPath string
102+
session *chdb.Session
103+
QueryFun queryHandle
51104
}
52105

53106
func (c *conn) Close() error {
54107
return nil
55108
}
56109

110+
func (c *conn) SetupQueryFun() {
111+
c.QueryFun = chdb.Query
112+
if c.session != nil {
113+
c.QueryFun = c.session.Query
114+
}
115+
}
116+
57117
func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) {
58118
namedValues := make([]driver.NamedValue, len(values))
59119
for i, value := range values {
@@ -67,7 +127,7 @@ func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) {
67127
}
68128

69129
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
70-
result := wrapper.Query(query, "Arrow")
130+
result := c.QueryFun(query, "Arrow", c.udfPath)
71131
buf := result.Buf()
72132
if buf == nil {
73133
return nil, fmt.Errorf("result is nil")

chdb/driver/driver_test.go

Lines changed: 109 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,128 @@
1-
package chdb
1+
package chdbdriver
22

33
import (
44
"database/sql"
5+
"fmt"
6+
"os"
57
"testing"
8+
9+
"github.com/chdb-io/chdb-go/chdb"
610
)
711

812
func TestDb(t *testing.T) {
913
db, err := sql.Open("chdb", "")
1014
if err != nil {
11-
t.Errorf("open db fail")
15+
t.Errorf("open db fail, err:%s", err)
1216
}
1317
if db.Ping() != nil {
1418
t.Errorf("ping db fail")
1519
}
16-
{
17-
rows, err := db.Query(`SELECT 1,'abc'`)
20+
rows, err := db.Query(`SELECT 1,'abc'`)
21+
if err != nil {
22+
t.Errorf("run Query fail, err:%s", err)
23+
}
24+
cols, err := rows.Columns()
25+
if err != nil {
26+
t.Errorf("get result columns fail, err: %s", err)
27+
}
28+
if len(cols) != 2 {
29+
t.Errorf("select result columns length should be 2")
30+
}
31+
var (
32+
bar int
33+
foo string
34+
)
35+
defer rows.Close()
36+
for rows.Next() {
37+
err := rows.Scan(&bar, &foo)
1838
if err != nil {
19-
t.Errorf("run Query fail, err:%s", err)
39+
t.Errorf("scan fail, err: %s", err)
40+
}
41+
if bar != 1 {
42+
t.Errorf("expected error")
43+
}
44+
if foo != "abc" {
45+
t.Errorf("expected error")
46+
}
47+
}
48+
}
49+
50+
func TestDbWithOpt(t *testing.T) {
51+
for _, kv := range []struct {
52+
opt string
53+
condition bool
54+
}{
55+
{"", false},
56+
{"udfPath=qq", false},
57+
{"udfPath=qq;session=ss", false},
58+
{"session=sssss", false},
59+
{"session=s2;udfPath=u1", false},
60+
{"session=s3;udfPath=u2;fooobar=ssss", false},
61+
{"foo;bar", true},
62+
} {
63+
db, err := sql.Open("chdb", kv.opt)
64+
if (err != nil) != kv.condition {
65+
t.Errorf("open db fail, err: %s", err)
66+
}
67+
if db == nil {
68+
continue
69+
}
70+
if (db.Ping() != nil) != kv.condition {
71+
t.Errorf("ping db fail")
2072
}
21-
cols, err := rows.Columns()
73+
}
74+
}
75+
76+
func TestDbWithSession(t *testing.T) {
77+
sessionDir, err := os.MkdirTemp("", "unittest-sessiondata")
78+
if err != nil {
79+
t.Fatalf("create temp directory fail, err: %s", err)
80+
}
81+
defer os.RemoveAll(sessionDir)
82+
session, err := chdb.NewSession(sessionDir)
83+
if err != nil {
84+
t.Fatalf("new session fail, err: %s", err)
85+
}
86+
defer session.Cleanup()
87+
88+
session.Query("CREATE DATABASE IF NOT EXISTS testdb; " +
89+
"CREATE TABLE IF NOT EXISTS testdb.testtable (id UInt32) ENGINE = MergeTree() ORDER BY id;")
90+
91+
session.Query("USE testdb; INSERT INTO testtable VALUES (1), (2), (3);")
92+
93+
ret := session.Query("SELECT * FROM testtable;")
94+
if string(ret.Buf()) != "1\n2\n3\n" {
95+
t.Errorf("Query result should be 1\n2\n3\n, got %s", string(ret.Buf()))
96+
}
97+
db, err := sql.Open("chdb", fmt.Sprintf("session=%s", sessionDir))
98+
if err != nil {
99+
t.Fatalf("open db fail, err: %s", err)
100+
}
101+
if db.Ping() != nil {
102+
t.Fatalf("ping db fail, err: %s", err)
103+
}
104+
rows, err := db.Query("select * from testtable;")
105+
if err != nil {
106+
t.Fatalf("exec create function fail, err: %s", err)
107+
}
108+
defer rows.Close()
109+
cols, err := rows.Columns()
110+
if err != nil {
111+
t.Fatalf("get result columns fail, err: %s", err)
112+
}
113+
if len(cols) != 1 {
114+
t.Fatalf("result columns length shoule be 3, actual: %d", len(cols))
115+
}
116+
var bar = 0
117+
var count = 1
118+
for rows.Next() {
119+
err = rows.Scan(&bar)
22120
if err != nil {
23-
t.Errorf("get result columns fail, err: %s", err)
24-
}
25-
if len(cols) != 2 {
26-
t.Errorf("select result columns length should be 2")
27-
}
28-
var (
29-
bar int
30-
foo string
31-
)
32-
defer rows.Close()
33-
for rows.Next() {
34-
err := rows.Scan(&bar, &foo)
35-
if err != nil {
36-
t.Errorf("scan fail, err: %s", err)
37-
}
38-
if bar != 1 {
39-
t.Errorf("expected error")
40-
}
41-
if foo != "abc" {
42-
t.Errorf("expected error")
43-
}
121+
t.Fatalf("scan fail, err: %s", err)
122+
}
123+
if bar != count {
124+
t.Fatalf("result is not match, want: %d actual: %d", count, bar)
44125
}
126+
count++
45127
}
46128
}

chdb/session.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
type Session struct {
12-
path string
12+
path string
1313
isTemp bool
1414
}
1515

@@ -37,15 +37,24 @@ func NewSession(paths ...string) (*Session, error) {
3737

3838
// Query calls queryToBuffer with a default output format of "CSV" if not provided.
3939
func (s *Session) Query(queryStr string, outputFormats ...string) *chdbstable.LocalResult {
40-
outputFormat := "CSV" // Default value
41-
if len(outputFormats) > 0 {
42-
outputFormat = outputFormats[0]
43-
}
44-
return queryToBuffer(queryStr, outputFormat, s.path, "")
40+
outputFormat := "CSV" // Default value
41+
udfPath := ""
42+
switch len(outputFormats) {
43+
case 0:
44+
case 1:
45+
outputFormat = outputFormats[0]
46+
case 2:
47+
fallthrough
48+
default:
49+
outputFormat = outputFormats[0]
50+
udfPath = outputFormats[1]
51+
}
52+
return queryToBuffer(queryStr, outputFormat, s.path, udfPath)
4553
}
4654

47-
// Close closes the session and removes the temporary directory
48-
// temporary directory is created when NewSession was called with an empty path.
55+
// Close closes the session and removes the temporary directory
56+
//
57+
// temporary directory is created when NewSession was called with an empty path.
4958
func (s *Session) Close() {
5059
// Remove the temporary directory if it starts with "chdb_"
5160
if s.isTemp && filepath.Base(s.path)[:5] == "chdb_" {

0 commit comments

Comments
 (0)