From e61a9bc37ab290f9159102a38accc1bd81806303 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Loi=CC=88c=20Alleyne?=
Date: Thu, 14 Nov 2024 10:26:38 -0500
Subject: [PATCH] NextBatch

Add DataReader.NextBatch and DataReader.RecordBatch so callers can drain
the converted record queue in slices of records, and switch cmd/main.go
over to the batch API (the previous per-record loop is kept commented
out).

Behavior of NextBatch:
- A batchSize below 1 is treated as 1.
- At most batchSize records are collected per call.
- Records from the previous batch are released at the start of the next
  call, so a batch is only valid until NextBatch is called again.
- A record still buffered in recChan when bldDone fires is appended to
  the current batch.
- When recChan is closed and a partial batch has been collected, the
  partial batch is still reported, unless r.err is set, in which case
  NextBatch returns false.

Next no longer calls r.wg.Wait() before receiving; NextBatch does.
Also document ReadToRecord and RecordBatch.

---
 cmd/main.go      | 25 +++++++++++++++------
 reader/reader.go | 58 +++++++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 73 insertions(+), 10 deletions(-)

diff --git a/cmd/main.go b/cmd/main.go
index 56a1730..165ba86 100644
--- a/cmd/main.go
+++ b/cmd/main.go
@@ -59,14 +59,25 @@ func main() {
 	log.Printf("elapsed: %v\n", time.Since(start))
 
 	i := 0
-	for r.Next() {
-		rec := r.Record()
-		_, err := rec.MarshalJSON()
-		if err != nil {
-			fmt.Printf("error marshaling record: %v\n", err)
+	// for r.Next() {
+	// 	rec := r.Record()
+	// 	_, err := rec.MarshalJSON()
+	// 	if err != nil {
+	// 		fmt.Printf("error marshaling record: %v\n", err)
+	// 	}
+	// 	// fmt.Printf("\nmarshaled record :\n%v\n", string(rj))
+	// 	i++
+	// }
+	for r.NextBatch(1024) {
+		recs := r.RecordBatch()
+		for _, rec := range recs {
+			_, err := rec.MarshalJSON()
+			if err != nil {
+				fmt.Printf("error marshaling record: %v\n", err)
+			}
+			// fmt.Printf("\nmarshaled record :\n%v\n", string(rj))
+			i++
 		}
-		// fmt.Printf("\nmarshaled record :\n%v\n", string(rj))
-		i++
 	}
 	log.Println("records", r.Count(), i)
 }
diff --git a/reader/reader.go b/reader/reader.go
index c437f23..cd03c9b 100644
--- a/reader/reader.go
+++ b/reader/reader.go
@@ -49,6 +49,7 @@ type DataReader struct {
 	bldMap     *fieldPos
 	ldr        *dataLoader
 	cur        arrow.Record
+	curBatch   []arrow.Record
 	readerCtx  context.Context
 	readCancel func()
 	err        error
@@ -111,6 +112,8 @@ func NewReader(schema *arrow.Schema, source DataSource, opts ...Option) (*DataRe
 	return r, nil
 }
 
+// ReadToRecord decodes a datum directly to an arrow.Record. The record
+// should be released by the user when done with it.
 func (r *DataReader) ReadToRecord(a any) (arrow.Record, error) {
 	var err error
 	defer func() {
@@ -147,6 +150,52 @@ func (r *DataReader) ReadToRecord(a any) (arrow.Record, error) {
 	return r.bld.NewRecord(), nil
 }
 
+// NextBatch returns whether a []arrow.Record of a specified size can be received
+// from the converted record queue. Will still return true if the queue channel is closed and
+// last batch of records available < batch size specified.
+// The user should check Err() after a call to NextBatch that returned false to check
+// if an error took place.
+func (r *DataReader) NextBatch(batchSize int) bool {
+	if batchSize < 1 {
+		batchSize = 1
+	}
+	if len(r.curBatch) != 0 {
+		for _, rec := range r.curBatch {
+			rec.Release()
+		}
+		r.curBatch = []arrow.Record{}
+	}
+	r.wg.Wait()
+
+	for len(r.curBatch) < batchSize {
+		select {
+		case rec, ok := <-r.recChan:
+			if !ok && rec == nil {
+				if len(r.curBatch) > 0 {
+					goto jump
+				}
+				return false
+			}
+			if rec != nil {
+				r.curBatch = append(r.curBatch, rec)
+			}
+		case <-r.bldDone:
+			if len(r.recChan) > 0 {
+				r.curBatch = append(r.curBatch, <-r.recChan)
+			}
+		case <-r.readerCtx.Done():
+			return false
+		}
+	}
+
+jump:
+	if r.err != nil {
+		return false
+	}
+
+	return len(r.curBatch) > 0
+}
+
 // Next returns whether a Record can be received from the converted record queue.
 // The user should check Err() after a call to Next that returned false to check
 // if an error took place.
@@ -156,7 +205,6 @@ func (r *DataReader) Next() bool {
 		r.cur.Release()
 		r.cur = nil
 	}
-	r.wg.Wait()
 
 	select {
 	case r.cur, ok = <-r.recChan:
@@ -195,8 +243,12 @@ func (r *DataReader) Opts() []Option { return r.opts }
 
 // Record returns the current Arrow record.
 // It is valid until the next call to Next.
-func (r *DataReader) Record() arrow.Record  { return r.cur }
-func (r *DataReader) Schema() *arrow.Schema { return r.schema }
+func (r *DataReader) Record() arrow.Record { return r.cur }
+
+// RecordBatch returns the current batch of Arrow records.
+// It is valid until the next call to NextBatch.
+func (r *DataReader) RecordBatch() []arrow.Record { return r.curBatch }
+func (r *DataReader) Schema() *arrow.Schema       { return r.schema }
 
 // Err returns the last error encountered during the reading of data.
 func (r *DataReader) Err() error { return r.err }