Skip to content

Commit 48ad725

Browse files
authored
Merge pull request #1 from chdb-io/feat/sql-interface
add: golang sql interface
2 parents d964ddb + 08bcbb4 commit 48ad725

File tree

13 files changed

+623
-121
lines changed

13 files changed

+623
-121
lines changed

chdb/driver/driver.go

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
package chdbdriver
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"database/sql"
7+
"database/sql/driver"
8+
"fmt"
9+
"reflect"
10+
"strings"
11+
"time"
12+
13+
"github.com/apache/arrow/go/v14/arrow"
14+
"github.com/apache/arrow/go/v14/arrow/array"
15+
"github.com/apache/arrow/go/v14/arrow/decimal128"
16+
"github.com/apache/arrow/go/v14/arrow/decimal256"
17+
"github.com/chdb-io/chdb-go/chdb"
18+
"github.com/chdb-io/chdb-go/chdbstable"
19+
20+
"github.com/apache/arrow/go/v14/arrow/ipc"
21+
)
22+
23+
const sessionOptionKey = "session"
24+
const udfPathOptionKey = "udfPath"
25+
26+
func init() {
27+
sql.Register("chdb", Driver{})
28+
}
29+
30+
type queryHandle func(string, ...string) *chdbstable.LocalResult
31+
32+
type connector struct {
33+
udfPath string
34+
session *chdb.Session
35+
}
36+
37+
// Connect returns a connection to a database.
38+
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
39+
cc := &conn{udfPath: c.udfPath, session: c.session}
40+
cc.SetupQueryFun()
41+
return cc, nil
42+
}
43+
44+
// Driver returns the underying Driver of the connector,
45+
// compatibility with the Driver method on sql.DB
46+
func (c *connector) Driver() driver.Driver { return Driver{} }
47+
48+
func parseConnectStr(str string) (ret map[string]string, err error) {
49+
ret = make(map[string]string)
50+
if len(str) == 0 {
51+
return
52+
}
53+
for _, kv := range strings.Split(str, ";") {
54+
parsed := strings.SplitN(kv, "=", 2)
55+
if len(parsed) != 2 {
56+
return nil, fmt.Errorf("invalid format for connection string, str: %s", kv)
57+
}
58+
59+
ret[strings.TrimSpace(parsed[0])] = strings.TrimSpace(parsed[1])
60+
}
61+
62+
return
63+
}
64+
func NewConnect(opts map[string]string) (ret *connector, err error) {
65+
ret = &connector{}
66+
sessionPath, ok := opts[sessionOptionKey]
67+
if ok {
68+
ret.session, err = chdb.NewSession(sessionPath)
69+
if err != nil {
70+
return nil, err
71+
}
72+
}
73+
udfPath, ok := opts[udfPathOptionKey]
74+
if ok {
75+
ret.udfPath = udfPath
76+
}
77+
return
78+
}
79+
80+
type Driver struct{}
81+
82+
// Open returns a new connection to the database.
83+
func (d Driver) Open(name string) (driver.Conn, error) {
84+
cc, err := d.OpenConnector(name)
85+
if err != nil {
86+
return nil, err
87+
}
88+
return cc.Connect(context.Background())
89+
}
90+
91+
// OpenConnector expects the same format as driver.Open
92+
func (d Driver) OpenConnector(name string) (driver.Connector, error) {
93+
opts, err := parseConnectStr(name)
94+
if err != nil {
95+
return nil, err
96+
}
97+
return NewConnect(opts)
98+
}
99+
100+
type conn struct {
101+
udfPath string
102+
session *chdb.Session
103+
QueryFun queryHandle
104+
}
105+
106+
func (c *conn) Close() error {
107+
return nil
108+
}
109+
110+
func (c *conn) SetupQueryFun() {
111+
c.QueryFun = chdb.Query
112+
if c.session != nil {
113+
c.QueryFun = c.session.Query
114+
}
115+
}
116+
117+
func (c *conn) Query(query string, values []driver.Value) (driver.Rows, error) {
118+
namedValues := make([]driver.NamedValue, len(values))
119+
for i, value := range values {
120+
namedValues[i] = driver.NamedValue{
121+
// nb: Name field is optional
122+
Ordinal: i,
123+
Value: value,
124+
}
125+
}
126+
return c.QueryContext(context.Background(), query, namedValues)
127+
}
128+
129+
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
130+
result := c.QueryFun(query, "Arrow", c.udfPath)
131+
buf := result.Buf()
132+
if buf == nil {
133+
return nil, fmt.Errorf("result is nil")
134+
}
135+
reader, err := ipc.NewFileReader(bytes.NewReader(buf))
136+
if err != nil {
137+
return nil, err
138+
}
139+
return &rows{localResult: result, reader: reader}, nil
140+
}
141+
142+
func (c *conn) Begin() (driver.Tx, error) {
143+
return nil, fmt.Errorf("does not support Transcation")
144+
}
145+
146+
func (c *conn) Prepare(query string) (driver.Stmt, error) {
147+
return c.PrepareContext(context.Background(), query)
148+
}
149+
150+
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
151+
return nil, fmt.Errorf("does not support prepare statement")
152+
}
153+
154+
// todo: func(c *conn) Prepare(query string)
155+
// todo: func(c *conn) PrepareContext(ctx context.Context, query string)
156+
// todo: prepared statment
157+
158+
type rows struct {
159+
localResult *chdbstable.LocalResult
160+
reader *ipc.FileReader
161+
curRecord arrow.Record
162+
curRow int64
163+
}
164+
165+
func (r *rows) Columns() (out []string) {
166+
sch := r.reader.Schema()
167+
for i := 0; i < sch.NumFields(); i++ {
168+
out = append(out, sch.Field(i).Name)
169+
}
170+
return
171+
}
172+
173+
func (r *rows) Close() error {
174+
if r.curRecord != nil {
175+
r.curRecord = nil
176+
}
177+
// ignore reader close
178+
_ = r.reader.Close()
179+
r.reader = nil
180+
r.localResult = nil
181+
return nil
182+
}
183+
184+
func (r *rows) Next(dest []driver.Value) error {
185+
if r.curRecord != nil && r.curRow == r.curRecord.NumRows() {
186+
r.curRecord = nil
187+
}
188+
for r.curRecord == nil {
189+
record, err := r.reader.Read()
190+
if err != nil {
191+
return err
192+
}
193+
if record.NumRows() == 0 {
194+
continue
195+
}
196+
r.curRecord = record
197+
r.curRow = 0
198+
}
199+
200+
for i, col := range r.curRecord.Columns() {
201+
if col.IsNull(int(r.curRow)) {
202+
dest[i] = nil
203+
continue
204+
}
205+
switch col := col.(type) {
206+
case *array.Boolean:
207+
dest[i] = col.Value(int(r.curRow))
208+
case *array.Int8:
209+
dest[i] = col.Value(int(r.curRow))
210+
case *array.Uint8:
211+
dest[i] = col.Value(int(r.curRow))
212+
case *array.Int16:
213+
dest[i] = col.Value(int(r.curRow))
214+
case *array.Uint16:
215+
dest[i] = col.Value(int(r.curRow))
216+
case *array.Int32:
217+
dest[i] = col.Value(int(r.curRow))
218+
case *array.Uint32:
219+
dest[i] = col.Value(int(r.curRow))
220+
case *array.Int64:
221+
dest[i] = col.Value(int(r.curRow))
222+
case *array.Uint64:
223+
dest[i] = col.Value(int(r.curRow))
224+
case *array.Float32:
225+
dest[i] = col.Value(int(r.curRow))
226+
case *array.Float64:
227+
dest[i] = col.Value(int(r.curRow))
228+
case *array.String:
229+
dest[i] = col.Value(int(r.curRow))
230+
case *array.LargeString:
231+
dest[i] = col.Value(int(r.curRow))
232+
case *array.Binary:
233+
dest[i] = col.Value(int(r.curRow))
234+
case *array.LargeBinary:
235+
dest[i] = col.Value(int(r.curRow))
236+
case *array.Date32:
237+
dest[i] = col.Value(int(r.curRow)).ToTime()
238+
case *array.Date64:
239+
dest[i] = col.Value(int(r.curRow)).ToTime()
240+
case *array.Time32:
241+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.Time32Type).Unit)
242+
case *array.Time64:
243+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.Time64Type).Unit)
244+
case *array.Timestamp:
245+
dest[i] = col.Value(int(r.curRow)).ToTime(col.DataType().(*arrow.TimestampType).Unit)
246+
case *array.Decimal128:
247+
dest[i] = col.Value(int(r.curRow))
248+
case *array.Decimal256:
249+
dest[i] = col.Value(int(r.curRow))
250+
default:
251+
return fmt.Errorf(
252+
"not yet implemented populating from columns of type " + col.DataType().String(),
253+
)
254+
}
255+
}
256+
257+
r.curRow++
258+
return nil
259+
}
260+
261+
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
262+
return r.reader.Schema().Field(index).Type.String()
263+
}
264+
265+
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
266+
return r.reader.Schema().Field(index).Nullable, true
267+
}
268+
269+
func (r *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
270+
typ := r.reader.Schema().Field(index).Type
271+
switch dt := typ.(type) {
272+
case *arrow.Decimal128Type:
273+
return int64(dt.Precision), int64(dt.Scale), true
274+
case *arrow.Decimal256Type:
275+
return int64(dt.Precision), int64(dt.Scale), true
276+
}
277+
return 0, 0, false
278+
}
279+
280+
func (r *rows) ColumnTypeScanType(index int) reflect.Type {
281+
switch r.reader.Schema().Field(index).Type.ID() {
282+
case arrow.BOOL:
283+
return reflect.TypeOf(false)
284+
case arrow.INT8:
285+
return reflect.TypeOf(int8(0))
286+
case arrow.UINT8:
287+
return reflect.TypeOf(uint8(0))
288+
case arrow.INT16:
289+
return reflect.TypeOf(int16(0))
290+
case arrow.UINT16:
291+
return reflect.TypeOf(uint16(0))
292+
case arrow.INT32:
293+
return reflect.TypeOf(int32(0))
294+
case arrow.UINT32:
295+
return reflect.TypeOf(uint32(0))
296+
case arrow.INT64:
297+
return reflect.TypeOf(int64(0))
298+
case arrow.UINT64:
299+
return reflect.TypeOf(uint64(0))
300+
case arrow.FLOAT32:
301+
return reflect.TypeOf(float32(0))
302+
case arrow.FLOAT64:
303+
return reflect.TypeOf(float64(0))
304+
case arrow.DECIMAL128:
305+
return reflect.TypeOf(decimal128.Num{})
306+
case arrow.DECIMAL256:
307+
return reflect.TypeOf(decimal256.Num{})
308+
case arrow.BINARY:
309+
return reflect.TypeOf([]byte{})
310+
case arrow.STRING:
311+
return reflect.TypeOf(string(""))
312+
case arrow.TIME32, arrow.TIME64, arrow.DATE32, arrow.DATE64, arrow.TIMESTAMP:
313+
return reflect.TypeOf(time.Time{})
314+
}
315+
return nil
316+
}

0 commit comments

Comments
 (0)