diff --git a/index/scorch/merge.go b/index/scorch/merge.go index eb94c12b8..9f1774b68 100644 --- a/index/scorch/merge.go +++ b/index/scorch/merge.go @@ -354,12 +354,8 @@ func (s *Scorch) planMergeAtSnapshot(ctx context.Context, return err } - switch segI := seg.(type) { - case segment.DiskStatsReporter: - totalBytesRead := segI.BytesRead() + prevBytesReadTotal - segI.ResetBytesRead(totalBytesRead) - seg = segI.(segment.Segment) - } + totalBytesRead := seg.BytesRead() + prevBytesReadTotal + seg.ResetBytesRead(totalBytesRead) oldNewDocNums = make(map[uint64][]uint64) for i, segNewDocNums := range newDocNums { @@ -438,9 +434,7 @@ type segmentMerge struct { func cumulateBytesRead(sbs []segment.Segment) uint64 { var rv uint64 for _, seg := range sbs { - if segI, diskStatsAvailable := seg.(segment.DiskStatsReporter); diskStatsAvailable { - rv += segI.BytesRead() - } + rv += seg.BytesRead() } return rv } diff --git a/index/scorch/reader_test.go b/index/scorch/reader_test.go index e1bca049c..68c7a3e57 100644 --- a/index/scorch/reader_test.go +++ b/index/scorch/reader_test.go @@ -153,6 +153,9 @@ func TestIndexReader(t *testing.T) { if err != nil { t.Errorf("unexpected error: %v", err) } + // Ignoring the BytesRead value, since it doesn't have + // relevance in this type of test + match.BytesRead = 0 if !reflect.DeepEqual(expectedMatch, match) { t.Errorf("got %#v, expected %#v", match, expectedMatch) } diff --git a/index/scorch/scorch.go b/index/scorch/scorch.go index 6eeaabfda..413684cf8 100644 --- a/index/scorch/scorch.go +++ b/index/scorch/scorch.go @@ -146,6 +146,7 @@ func NewScorch(storeName string, if ok { rv.onAsyncError = RegistryAsyncErrorCallbacks[aecbName] } + return rv, nil } diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go index 511b9ef6b..5db501a88 100644 --- a/index/scorch/snapshot_index.go +++ b/index/scorch/snapshot_index.go @@ -60,8 +60,6 @@ var reflectStaticSizeIndexSnapshot int // in the kvConfig. var DefaultFieldTFRCacheThreshold uint64 = 10 -type diskStatsReporter segment.DiskStatsReporter - func init() { var is interface{} = IndexSnapshot{} reflectStaticSizeIndexSnapshot = int(reflect.TypeOf(is).Size()) @@ -149,18 +147,15 @@ func (i *IndexSnapshot) newIndexSnapshotFieldDict(field string, for index, segment := range i.segment { go func(index int, segment *SegmentSnapshot) { var prevBytesRead uint64 - seg, diskStatsAvailable := segment.segment.(diskStatsReporter) - if diskStatsAvailable { - prevBytesRead = seg.BytesRead() - } + prevBytesRead = segment.segment.BytesRead() + dict, err := segment.segment.Dictionary(field) if err != nil { results <- &asynchSegmentResult{err: err} } else { - if diskStatsAvailable { - atomic.AddUint64(&i.parent.stats.TotBytesReadAtQueryTime, - seg.BytesRead()-prevBytesRead) - } + atomic.AddUint64(&i.parent.stats.TotBytesReadAtQueryTime, + segment.segment.BytesRead()-prevBytesRead) + if randomLookup { results <- &asynchSegmentResult{dict: dict} } else { @@ -435,11 +430,8 @@ func (i *IndexSnapshot) Document(id string) (rv index.Document, err error) { segmentIndex, localDocNum := i.segmentIndexAndLocalDocNumFromGlobal(docNum) rvd := document.NewDocument(id) - var prevBytesRead uint64 - seg, diskStatsAvailable := i.segment[segmentIndex].segment.(segment.DiskStatsReporter) - if diskStatsAvailable { - prevBytesRead = seg.BytesRead() - } + prevBytesRead := i.segment[segmentIndex].segment.BytesRead() + err = i.segment[segmentIndex].VisitDocument(localDocNum, func(name string, typ byte, val []byte, pos []uint64) bool { if name == "_id" { return true @@ -471,8 +463,8 @@ func (i *IndexSnapshot) Document(id string) (rv index.Document, err error) { if err != nil { return nil, err } - if diskStatsAvailable { - delta := seg.BytesRead() - prevBytesRead + + if delta := i.segment[segmentIndex].segment.BytesRead() - prevBytesRead; delta > 0 { atomic.AddUint64(&i.parent.stats.TotBytesReadAtQueryTime, delta) } return rvd, nil @@ -549,17 +541,13 @@ func (is *IndexSnapshot) TermFieldReader(term []byte, field string, includeFreq, if rv.dicts == nil { rv.dicts = make([]segment.TermDictionary, len(is.segment)) for i, segment := range is.segment { - var prevBytesRead uint64 - segP, diskStatsAvailable := segment.segment.(diskStatsReporter) - if diskStatsAvailable { - prevBytesRead = segP.BytesRead() - } + prevBytesRead := segment.segment.BytesRead() dict, err := segment.segment.Dictionary(field) if err != nil { return nil, err } - if diskStatsAvailable { - atomic.AddUint64(&is.parent.stats.TotBytesReadAtQueryTime, segP.BytesRead()-prevBytesRead) + if bytesRead := segment.segment.BytesRead(); bytesRead > prevBytesRead { + atomic.AddUint64(&is.parent.stats.TotBytesReadAtQueryTime, bytesRead-prevBytesRead) } rv.dicts[i] = dict } @@ -567,8 +555,8 @@ func (is *IndexSnapshot) TermFieldReader(term []byte, field string, includeFreq, for i, segment := range is.segment { var prevBytesReadPL uint64 - if postings, diskStatsAvailable := rv.postings[i].(diskStatsReporter); diskStatsAvailable { - prevBytesReadPL = postings.BytesRead() + if rv.postings[i] != nil { + prevBytesReadPL = rv.postings[i].BytesRead() } pl, err := rv.dicts[i].PostingsList(term, segment.deleted, rv.postings[i]) if err != nil { @@ -577,21 +565,19 @@ func (is *IndexSnapshot) TermFieldReader(term []byte, field string, includeFreq, rv.postings[i] = pl var prevBytesReadItr uint64 - if itr, diskStatsAvailable := rv.iterators[i].(diskStatsReporter); diskStatsAvailable { - prevBytesReadItr = itr.BytesRead() + if rv.iterators[i] != nil { + prevBytesReadItr = rv.iterators[i].BytesRead() } rv.iterators[i] = pl.Iterator(includeFreq, includeNorm, includeTermVectors, rv.iterators[i]) - if postings, diskStatsAvailable := pl.(diskStatsReporter); diskStatsAvailable && - prevBytesReadPL < postings.BytesRead() { + if bytesRead := rv.postings[i].BytesRead(); prevBytesReadPL < bytesRead { atomic.AddUint64(&is.parent.stats.TotBytesReadAtQueryTime, - postings.BytesRead()-prevBytesReadPL) + bytesRead-prevBytesReadPL) } - if itr, diskStatsAvailable := rv.iterators[i].(diskStatsReporter); diskStatsAvailable && - prevBytesReadItr < itr.BytesRead() { + if bytesRead := rv.iterators[i].BytesRead(); prevBytesReadItr < bytesRead { atomic.AddUint64(&is.parent.stats.TotBytesReadAtQueryTime, - itr.BytesRead()-prevBytesReadItr) + bytesRead-prevBytesReadItr) } } atomic.AddUint64(&is.parent.stats.TotTermSearchersStarted, uint64(1)) @@ -711,17 +697,13 @@ func (i *IndexSnapshot) documentVisitFieldTermsOnSegment( } if ssvOk && ssv != nil && len(vFields) > 0 { - var prevBytesRead uint64 - ssvp, diskStatsAvailable := ssv.(segment.DiskStatsReporter) - if diskStatsAvailable { - prevBytesRead = ssvp.BytesRead() - } + prevBytesRead := ss.segment.BytesRead() dvs, err = ssv.VisitDocValues(localDocNum, fields, visitor, dvs) if err != nil { return nil, nil, err } - if diskStatsAvailable { - atomic.AddUint64(&i.parent.stats.TotBytesReadAtQueryTime, ssvp.BytesRead()-prevBytesRead) + if delta := ss.segment.BytesRead() - prevBytesRead; delta > 0 { + atomic.AddUint64(&i.parent.stats.TotBytesReadAtQueryTime, delta) } } @@ -889,6 +871,10 @@ func (i *IndexSnapshot) CopyTo(d index.Directory) error { return copyBolt.Sync() } +func (s *IndexSnapshot) UpdateIOStats(val uint64) { + atomic.AddUint64(&s.parent.stats.TotBytesReadAtQueryTime, val) +} + func (i *IndexSnapshot) GetSpatialAnalyzerPlugin(typ string) ( index.SpatialAnalyzerPlugin, error) { var rv index.SpatialAnalyzerPlugin diff --git a/index/scorch/snapshot_index_tfr.go b/index/scorch/snapshot_index_tfr.go index 7283b1371..bd20eb2e6 100644 --- a/index/scorch/snapshot_index_tfr.go +++ b/index/scorch/snapshot_index_tfr.go @@ -76,11 +76,7 @@ func (i *IndexSnapshotTermFieldReader) Next(preAlloced *index.TermFieldDoc) (*in } // find the next hit for i.segmentOffset < len(i.iterators) { - prevBytesRead := uint64(0) - itr, diskStatsAvailable := i.iterators[i.segmentOffset].(segment.DiskStatsReporter) - if diskStatsAvailable { - prevBytesRead = itr.BytesRead() - } + prevBytesRead := i.iterators[i.segmentOffset].BytesRead() next, err := i.iterators[i.segmentOffset].Next() if err != nil { return nil, err @@ -98,9 +94,8 @@ func (i *IndexSnapshotTermFieldReader) Next(preAlloced *index.TermFieldDoc) (*in // this is because there are chances of having a series of loadChunk calls, // and they have to be added together before sending the bytesRead at this point // upstream. - if diskStatsAvailable { - delta := itr.BytesRead() - prevBytesRead - atomic.AddUint64(&i.snapshot.parent.stats.TotBytesReadAtQueryTime, uint64(delta)) + if delta := i.iterators[i.segmentOffset].BytesRead() - prevBytesRead; delta > 0 { + rv.BytesRead = delta } return rv, nil diff --git a/index_impl.go b/index_impl.go index 1945568ea..407f1ff5b 100644 --- a/index_impl.go +++ b/index_impl.go @@ -26,6 +26,7 @@ import ( "time" "github.com/blevesearch/bleve/v2/document" + "github.com/blevesearch/bleve/v2/index/scorch" "github.com/blevesearch/bleve/v2/index/upsidedown" "github.com/blevesearch/bleve/v2/mapping" "github.com/blevesearch/bleve/v2/registry" @@ -527,12 +528,21 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr }() } } + var totalBytesRead uint64 + SendBytesRead := func(bytesRead uint64) { + totalBytesRead = bytesRead + } + ctx = context.WithValue(ctx, collector.SearchIOStatsCallbackKey, + collector.SearchIOStatsCallbackFunc(SendBytesRead)) err = coll.Collect(ctx, searcher, indexReader) if err != nil { return nil, err } + if sr, ok := indexReader.(*scorch.IndexSnapshot); ok { + sr.UpdateIOStats(totalBytesRead) + } hits := coll.Results() var highlighter highlight.Highlighter diff --git a/index_test.go b/index_test.go index 590a132bb..a7914ee63 100644 --- a/index_test.go +++ b/index_test.go @@ -371,6 +371,7 @@ func TestBytesRead(t *testing.T) { typeFieldMapping := NewTextFieldMapping() typeFieldMapping.Store = false documentMapping.AddFieldMappingsAt("type", typeFieldMapping) + idx, err := NewUsing(tmpIndexPath, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) if err != nil { t.Fatal(err) @@ -579,7 +580,7 @@ func TestBytesReadStored(t *testing.T) { stats, _ := idx.StatsMap()["index"].(map[string]interface{}) bytesRead, _ := stats["num_bytes_read_at_query_time"].(uint64) if bytesRead != 15792 { - t.Fatalf("expected the bytes read stat to be around 15792, got %v", err) + t.Fatalf("expected the bytes read stat to be around 15792, got %v", bytesRead) } prevBytesRead := bytesRead @@ -591,7 +592,7 @@ func TestBytesReadStored(t *testing.T) { stats, _ = idx.StatsMap()["index"].(map[string]interface{}) bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64) if bytesRead-prevBytesRead != 15 { - t.Fatalf("expected the bytes read stat to be around 15, got %v", err) + t.Fatalf("expected the bytes read stat to be around 15, got %v", bytesRead-prevBytesRead) } prevBytesRead = bytesRead @@ -660,7 +661,7 @@ func TestBytesReadStored(t *testing.T) { stats, _ = idx1.StatsMap()["index"].(map[string]interface{}) bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64) if bytesRead-prevBytesRead != 12 { - t.Fatalf("expected the bytes read stat to be around 12, got %v", err) + t.Fatalf("expected the bytes read stat to be around 12, got %v", bytesRead-prevBytesRead) } prevBytesRead = bytesRead @@ -674,7 +675,7 @@ func TestBytesReadStored(t *testing.T) { bytesRead, _ = stats["num_bytes_read_at_query_time"].(uint64) if bytesRead-prevBytesRead != 646 { - t.Fatalf("expected the bytes read stat to be around 646, got %v", err) + t.Fatalf("expected the bytes read stat to be around 646, got %v", bytesRead-prevBytesRead) } } diff --git a/search/collector/topn.go b/search/collector/topn.go index 13d31e06f..7c9db9ff0 100644 --- a/search/collector/topn.go +++ b/search/collector/topn.go @@ -49,6 +49,10 @@ type collectorCompare func(i, j *search.DocumentMatch) int type collectorFixup func(d *search.DocumentMatch) error +const SearchIOStatsCallbackKey = "_search_io_stats_callback_key" + +type SearchIOStatsCallbackFunc func(uint64) + // TopNCollector collects the top N hits, optionally skipping some results type TopNCollector struct { size int @@ -197,7 +201,7 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, } hc.needDocIds = hc.needDocIds || loadID - + var totalBytesRead uint64 select { case <-ctx.Done(): return ctx.Err() @@ -205,6 +209,7 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, next, err = searcher.Next(searchContext) } for err == nil && next != nil { + totalBytesRead += next.BytesRead if hc.total%CheckDoneEvery == 0 { select { case <-ctx.Done(): @@ -226,6 +231,11 @@ func (hc *TopNCollector) Collect(ctx context.Context, searcher search.Searcher, next, err = searcher.Next(searchContext) } + statsCallbackFn := ctx.Value(SearchIOStatsCallbackKey) + if statsCallbackFn != nil { + statsCallbackFn.(SearchIOStatsCallbackFunc)(totalBytesRead) + } + // help finalize/flush the results in case // of custom document match handlers. err = dmHandler(nil) diff --git a/search/scorer/scorer_conjunction.go b/search/scorer/scorer_conjunction.go index f3c81a78c..f5dd8ca54 100644 --- a/search/scorer/scorer_conjunction.go +++ b/search/scorer/scorer_conjunction.go @@ -41,7 +41,13 @@ func NewConjunctionQueryScorer(options search.SearcherOptions) *ConjunctionQuery options: options, } } - +func getTotalBytesRead(matches []*search.DocumentMatch) uint64 { + var rv uint64 + for _, match := range matches { + rv += match.BytesRead + } + return rv +} func (s *ConjunctionQueryScorer) Score(ctx *search.SearchContext, constituents []*search.DocumentMatch) *search.DocumentMatch { var sum float64 var childrenExplanations []*search.Explanation @@ -67,6 +73,7 @@ func (s *ConjunctionQueryScorer) Score(ctx *search.SearchContext, constituents [ rv.Expl = newExpl rv.FieldTermLocations = search.MergeFieldTermLocations( rv.FieldTermLocations, constituents[1:]) + rv.BytesRead = getTotalBytesRead(constituents) return rv } diff --git a/search/scorer/scorer_disjunction.go b/search/scorer/scorer_disjunction.go index 054e76fd4..fd9d0bb0f 100644 --- a/search/scorer/scorer_disjunction.go +++ b/search/scorer/scorer_disjunction.go @@ -78,6 +78,6 @@ func (s *DisjunctionQueryScorer) Score(ctx *search.SearchContext, constituents [ rv.Expl = newExpl rv.FieldTermLocations = search.MergeFieldTermLocations( rv.FieldTermLocations, constituents[1:]) - + rv.BytesRead = getTotalBytesRead(constituents) return rv } diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index ca268648b..ce5f202d8 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -198,6 +198,6 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term }) } } - + rv.BytesRead = termMatch.BytesRead return rv } diff --git a/search/search.go b/search/search.go index d2dd33712..d45491b4a 100644 --- a/search/search.go +++ b/search/search.go @@ -166,6 +166,8 @@ type DocumentMatch struct { // be later incorporated into the Locations map when search // results are completed FieldTermLocations []FieldTermLocation `json:"-"` + + BytesRead uint64 `json:"-"` } func (dm *DocumentMatch) AddFieldValue(name string, value interface{}) { diff --git a/search/searcher/search_disjunction.go b/search/searcher/search_disjunction.go index 4cee46841..a2f1cf2ab 100644 --- a/search/searcher/search_disjunction.go +++ b/search/searcher/search_disjunction.go @@ -16,6 +16,7 @@ package searcher import ( "fmt" + "github.com/blevesearch/bleve/v2/search" index "github.com/blevesearch/bleve_index_api" ) diff --git a/search/searcher/search_disjunction_slice.go b/search/searcher/search_disjunction_slice.go index 79fee9f4d..63ee7ef2e 100644 --- a/search/searcher/search_disjunction_slice.go +++ b/search/searcher/search_disjunction_slice.go @@ -156,7 +156,6 @@ func (s *DisjunctionSliceSearcher) updateMatches() error { matchingIdxs = matchingIdxs[:0] } } - matching = append(matching, curr) matchingIdxs = append(matchingIdxs, i) } diff --git a/test/versus_test.go b/test/versus_test.go index eef6407a3..e96eae6ad 100644 --- a/test/versus_test.go +++ b/test/versus_test.go @@ -349,6 +349,7 @@ func testVersusSearches(vt *VersusTest, searchTemplates []string, idxA, idxB ble t.Errorf("\n doc: %d, body: %s", idx, strings.Join(vt.Bodies[idx], " ")) } } + if !reflect.DeepEqual(hitsA, hitsB) { t.Errorf("=========\nsearch: (%d) %s,\n res hits mismatch,\n len(hitsA): %d,\n len(hitsB): %d", i, bufBytes, len(hitsA), len(hitsB)) @@ -386,7 +387,9 @@ func hitsById(res *bleve.SearchResult) map[string]*search.DocumentMatch { hit.Score = math.Trunc(hit.Score*1000.0) / 1000.0 hit.IndexInternalID = nil hit.HitNumber = 0 - + // Ignoring the BytesRead value, since it doesn't have + // relevance in this type of test + hit.BytesRead = 0 rv[hit.ID] = hit }