Skip to content

Commit 1c0bbbe

Browse files
committed
impl: rows logic
1 parent 89e79f6 commit 1c0bbbe

File tree

2 files changed

+149
-17
lines changed

2 files changed

+149
-17
lines changed

chdb/driver/driver.go

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@ import (
77
"database/sql/driver"
88
"fmt"
99
"reflect"
10+
"time"
1011

12+
"github.com/apache/arrow/go/v14/arrow"
13+
"github.com/apache/arrow/go/v14/arrow/array"
14+
"github.com/apache/arrow/go/v14/arrow/decimal128"
15+
"github.com/apache/arrow/go/v14/arrow/decimal256"
1116
wrapper "github.com/chdb-io/chdb-go/chdb"
1217
"github.com/chdb-io/chdb-go/chdbstable"
1318

@@ -93,6 +98,8 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
9398
type rows struct {
9499
localResult *chdbstable.LocalResult
95100
reader *ipc.FileReader
101+
curRecord arrow.Record
102+
curRow int64
96103
}
97104

98105
func (r *rows) Columns() (out []string) {
@@ -104,25 +111,146 @@ func (r *rows) Columns() (out []string) {
104111
}
105112

106113
func (r *rows) Close() error {
114+
if r.curRecord != nil {
115+
r.curRecord = nil
116+
}
117+
// ignore reader close
118+
_ = r.reader.Close()
119+
r.reader = nil
120+
r.localResult = nil
107121
return nil
108122
}
109123

110124
func (r *rows) Next(dest []driver.Value) error {
125+
if r.curRecord != nil && r.curRow == r.curRecord.NumRows() {
126+
r.curRecord = nil
127+
}
128+
for r.curRecord == nil {
129+
record, err := r.reader.Read()
130+
if err != nil {
131+
return err
132+
}
133+
if record.NumRows() == 0 {
134+
continue
135+
}
136+
r.curRecord = record
137+
r.curRow = 0
138+
}
139+
140+
for i, col := range r.curRecord.Columns() {
141+
if col.IsNull(int(r.curRow)) {
142+
dest[i] = nil
143+
continue
144+
}
145+
switch col := col.(type) {
146+
case *array.Boolean:
147+
dest[i] = col.Value(int(r.curRow))
148+
case *array.Int8:
149+
dest[i] = col.Value(int(r.curRow))
150+
case *array.Uint8:
151+
dest[i] = col.Value(int(r.curRow))
152+
case *array.Int16:
153+
dest[i] = col.Value(int(r.curRow))
154+
case *array.Uint16:
155+
dest[i] = col.Value(int(r.curRow))
156+
case *array.Int32:
157+
dest[i] = col.Value(int(r.curRow))
158+
case *array.Uint32:
159+
dest[i] = col.Value(int(r.curRow))
160+
case *array.Int64:
161+
dest[i] = col.Value(int(r.curRow))
162+
case *array.Uint64:
163+
dest[i] = col.Value(int(r.curRow))
164+
case *array.Float32:
165+
dest[i] = col.Value(int(r.curRow))
166+
case *array.Float64:
167+
dest[i] = col.Value(int(r.curRow))
168+
case *array.String:
169+
dest[i] = col.Value(int(r.curRow))
170+
case *array.LargeString:
171+
dest[i] = col.Value(int(r.curRow))
172+
case *array.Binary:
173+
dest[i] = col.Value(int(r.curRow))
174+
case *array.LargeBinary:
175+
dest[i] = col.Value(int(r.curRow))
176+
case *array.Date32:
177+
dest[i] = col.Value(int(r.curRow)).ToTime()
178+
case *array.Date64:
179+
dest[i] = col.Value(int(r.curRow)).ToTime()
180+
case *array.Time32:
181+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.Time32Type).Unit)
182+
case *array.Time64:
183+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.Time64Type).Unit)
184+
case *array.Timestamp:
185+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.TimestampType).Unit)
186+
case *array.Decimal128:
187+
dest[i] = col.Value(int(r.curRow))
188+
case *array.Decimal256:
189+
dest[i] = col.Value(int(r.curRow))
190+
default:
191+
return fmt.Errorf(
192+
"not yet implemented populating from columns of type " + col.DataType().String(),
193+
)
194+
}
195+
}
196+
197+
r.curRow++
111198
return nil
112199
}
113200

114201
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
115-
return ""
202+
return r.reader.Schema().Field(index).Type.String()
116203
}
117204

118205
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
119-
return
206+
return r.reader.Schema().Field(index).Nullable, true
120207
}
121208

122209
func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
123-
return
210+
typ := r.reader.Schema().Field(index).Type
211+
switch dt := typ.(type) {
212+
case *arrow.Decimal128Type:
213+
return int64(dt.Precision), int64(dt.Scale), true
214+
case *arrow.Decimal256Type:
215+
return int64(dt.Precision), int64(dt.Scale), true
216+
}
217+
return 0, 0, false
124218
}
125219

126220
func (r *rows) ColumnTypeScanType(index int) reflect.Type {
127-
return reflect.TypeOf(nil)
221+
switch r.reader.Schema().Field(index).Type.ID() {
222+
case arrow.BOOL:
223+
return reflect.TypeOf(false)
224+
case arrow.INT8:
225+
return reflect.TypeOf(int8(0))
226+
case arrow.UINT8:
227+
return reflect.TypeOf(uint8(0))
228+
case arrow.INT16:
229+
return reflect.TypeOf(int16(0))
230+
case arrow.UINT16:
231+
return reflect.TypeOf(uint16(0))
232+
case arrow.INT32:
233+
return reflect.TypeOf(int32(0))
234+
case arrow.UINT32:
235+
return reflect.TypeOf(uint32(0))
236+
case arrow.INT64:
237+
return reflect.TypeOf(int64(0))
238+
case arrow.UINT64:
239+
return reflect.TypeOf(uint64(0))
240+
case arrow.FLOAT32:
241+
return reflect.TypeOf(float32(0))
242+
case arrow.FLOAT64:
243+
return reflect.TypeOf(float64(0))
244+
case arrow.DECIMAL128:
245+
return reflect.TypeOf(decimal128.Num{})
246+
case arrow.DECIMAL256:
247+
return reflect.TypeOf(decimal256.Num{})
248+
case arrow.BINARY:
249+
return reflect.TypeOf([]byte{})
250+
case arrow.STRING:
251+
return reflect.TypeOf(string(""))
252+
case arrow.TIME32, arrow.TIME64, arrow.DATE32, arrow.DATE64, arrow.TIMESTAMP:
253+
return reflect.TypeOf(time.Time{})
254+
}
255+
return nil
128256
}

chdb/driver/driver_test.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,6 @@ func TestDb(t *testing.T) {
1313
if db.Ping() != nil {
1414
t.Errorf("ping db fail")
1515
}
16-
{
17-
rows, err := db.Query("SELECT version()")
18-
if err != nil {
19-
t.Errorf("run Query fail, err:%s", err)
20-
}
21-
cols, err := rows.Columns()
22-
if err != nil {
23-
t.Errorf("get result columns fail, err: %s", err)
24-
}
25-
if len(cols) != 1 {
26-
t.Errorf("select version(), result columns length should be 1")
27-
}
28-
}
2916
{
3017
rows, err := db.Query(`SELECT 1,'abc'`)
3118
if err != nil {
@@ -38,5 +25,22 @@ func TestDb(t *testing.T) {
3825
if len(cols) != 2 {
3926
t.Errorf("select version(), result columns length should be 1")
4027
}
28+
var (
29+
bar int
30+
foo string
31+
)
32+
defer rows.Close()
33+
for rows.Next() {
34+
err := rows.Scan(&bar, &foo)
35+
if err != nil {
36+
t.Errorf("scan fail, err: %s", err)
37+
}
38+
if bar != 1 {
39+
t.Errorf("expected error")
40+
}
41+
if foo != "abc" {
42+
t.Errorf("expected error")
43+
}
44+
}
4145
}
4246
}

0 commit comments

Comments
 (0)