From bbe4ae74a39e91dafc519b2fbfd7b69b3aa83b87 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Thu, 5 Dec 2024 12:29:45 +0530 Subject: [PATCH 01/27] hacky start --- index_alias_impl.go | 30 ++++++++++++++++++++++++++++-- index_impl.go | 22 +++++++++++++++++++++- mapping/mapping.go | 3 +++ pre_search.go | 22 ++++++++++++++++++++++ search.go | 3 +++ search/searcher/search_term.go | 1 + search/util.go | 2 ++ 7 files changed, 80 insertions(+), 3 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 766240b4a..805ccbf0b 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -195,6 +195,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest flags := &preSearchFlags{ knn: requestHasKNN(req), synonyms: !isMatchNoneQuery(req.Query), + bm25: true, // TODO Just force setting it to true to test } return preSearchDataSearch(ctx, req, flags, i.indexes...) } @@ -573,6 +574,7 @@ type asyncSearchResult struct { type preSearchFlags struct { knn bool synonyms bool + bm25 bool // needs presearch for this too } // preSearchRequired checks if preSearch is required and returns a boolean flag @@ -598,18 +600,30 @@ func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFl } } } - if knn || synonyms { + var bm25 bool + if !isMatchNoneQuery(req.Query) { + // todo fix this cuRRENTLY ALL INDEX mappings are BM25 mappings, need to fix + // this is just a placeholder. + if _, ok := m.(mapping.BM25Mapping); ok { + bm25 = true + } + } + + if knn || synonyms || bm25 { return &preSearchFlags{ knn: knn, synonyms: synonyms, + bm25: bm25, }, nil } return nil, nil } func preSearch(ctx context.Context, req *SearchRequest, flags *preSearchFlags, indexes ...Index) (*SearchResult, error) { + // create a dummy request with a match none query + // since we only care about the preSearchData in PreSearch var dummyQuery = req.Query - if !flags.synonyms { + if !flags.bm25 || !flags.synonyms { // create a dummy request with a match none query // since we only care about the preSearchData in PreSearch dummyQuery = query.NewMatchNoneQuery() @@ -691,6 +705,11 @@ func constructSynonymPreSearchData(rv map[string]map[string]interface{}, sr *Sea for _, index := range indexes { rv[index.Name()][search.SynonymPreSearchDataKey] = sr.SynonymResult } +} +func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *SearchResult, indexes []Index) map[string]map[string]interface{} { + for _, index := range indexes { + rv[index.Name()][search.BM25PreSearchDataKey] = sr.totalDocCount + } return rv } @@ -712,6 +731,8 @@ func constructPreSearchData(req *SearchRequest, flags *preSearchFlags, } if flags.synonyms { mergedOut = constructSynonymPreSearchData(mergedOut, preSearchResult, indexes) + if flags.bm25 { + mergedOut = constructBM25PreSearchData(mergedOut, preSearchResult, indexes) } return mergedOut, nil } @@ -820,6 +841,11 @@ func redistributePreSearchData(req *SearchRequest, indexes []Index) (map[string] if fts, ok := req.PreSearchData[search.SynonymPreSearchDataKey].(search.FieldTermSynonymMap); ok { for _, index := range indexes { rv[index.Name()][search.SynonymPreSearchDataKey] = fts + + // TODO Extend to more stats + if totalDocCount, ok := req.PreSearchData[search.BM25PreSearchDataKey].(uint64); ok { + for _, index := range indexes { + rv[index.Name()][search.BM25PreSearchDataKey] = totalDocCount } } return rv, nil diff --git a/index_impl.go b/index_impl.go index 289014f6c..ac32d6dbc 100644 --- a/index_impl.go +++ b/index_impl.go @@ -483,8 +483,9 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in return nil, err } } - + var fts search.FieldTermSynonymMap + var count uint64 if !isMatchNoneQuery(req.Query) { if synMap, ok := i.m.(mapping.SynonymMapping); ok { if synReader, ok := reader.(index.ThesaurusReader); ok { @@ -494,6 +495,13 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in } } } + + if _, ok := i.m.(mapping.BM25Mapping); ok { + count, err = reader.DocCount() + if err != nil { + return nil, err + } + } } return &SearchResult{ @@ -503,6 +511,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in }, Hits: knnHits, SynonymResult: fts, + totalDocCount: count, }, nil } @@ -558,6 +567,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr var fts search.FieldTermSynonymMap var skipSynonymCollector bool + var bm25TotalDocs uint64 var ok bool if req.PreSearchData != nil { for k, v := range req.PreSearchData { @@ -578,6 +588,14 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr } skipSynonymCollector = true } + skipKnnCollector = true + case search.BM25PreSearchDataKey: + if v != nil { + bm25TotalDocs, ok = v.(uint64) + if !ok { + return nil, fmt.Errorf("bm25 preSearchData must be of type uint64") + } + } } } } @@ -603,6 +621,8 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if fts != nil { ctx = context.WithValue(ctx, search.FieldTermSynonymMapKey, fts) + if bm25TotalDocs > 0 { + ctx = context.WithValue(ctx, search.BM25MapKey, bm25TotalDocs) } // This callback and variable handles the tracking of bytes read diff --git a/mapping/mapping.go b/mapping/mapping.go index a6c1591b8..0992f9837 100644 --- a/mapping/mapping.go +++ b/mapping/mapping.go @@ -74,3 +74,6 @@ type SynonymMapping interface { SynonymSourceVisitor(visitor analysis.SynonymSourceVisitor) error } +type BM25Mapping interface { + IndexMapping +} diff --git a/pre_search.go b/pre_search.go index 5fd710d68..13f1b496a 100644 --- a/pre_search.go +++ b/pre_search.go @@ -82,6 +82,23 @@ func (s *synonymPreSearchResultProcessor) finalize(sr *SearchResult) { } } +type bm25PreSearchResultProcessor struct { + docCount uint64 // bm25 specific stats +} + +func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { + return &bm25PreSearchResultProcessor{} +} + +// TODO How will this work for queries other than term queries? +func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { + b.docCount += (sr.totalDocCount) +} + +func (b *bm25PreSearchResultProcessor) finalize(sr *SearchResult) { + +} + // ----------------------------------------------------------------------------- // Master struct that can hold any number of presearch result processors type compositePreSearchResultProcessor struct { @@ -122,6 +139,11 @@ func createPreSearchResultProcessor(req *SearchRequest, flags *preSearchFlags) p processors = append(processors, synonymProcessor) } } + if flags.bm25 { + if bm25Processtor := newBM25PreSearchResultProcessor(); bm25Processtor != nil { + processors = append(processors, bm25Processtor) + } + } // Return based on the number of processors, optimizing for the common case of 1 processor // If there are no processors, return nil switch len(processors) { diff --git a/search.go b/search.go index 72bfca5e2..095d58ab2 100644 --- a/search.go +++ b/search.go @@ -447,6 +447,9 @@ type SearchResult struct { // special fields that are applicable only for search // results that are obtained from a presearch SynonymResult search.FieldTermSynonymMap `json:"synonym_result,omitempty"` + // The following fields are applicable to BM25 preSearch + // todo add more fields beyond docCount + totalDocCount uint64 } func (sr *SearchResult) Size() int { diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index c519d8d51..622a12c3d 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -65,6 +65,7 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, te func newTermSearcherFromReader(indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { + // TODO Instead of passing count from reader here, do it using the presearch phase stats. count, err := indexReader.DocCount() if err != nil { _ = reader.Close() diff --git a/search/util.go b/search/util.go index 2e95f1180..0aa8c7438 100644 --- a/search/util.go +++ b/search/util.go @@ -137,6 +137,7 @@ type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool const KnnPreSearchDataKey = "_knn_pre_search_data_key" const SynonymPreSearchDataKey = "_synonym_pre_search_data_key" +const BM25PreSearchDataKey = "_bm25_pre_search_data_key" const PreSearchKey = "_presearch_key" @@ -162,6 +163,7 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { } const FieldTermSynonymMapKey = "_field_term_synonym_map_key" +const BM25MapKey = "_bm25_map_key" const SearcherStartCallbackKey = "_searcher_start_callback_key" const SearcherEndCallbackKey = "_searcher_end_callback_key" From a679009326abe788cce5454f1aa0446b47044c66 Mon Sep 17 00:00:00 2001 From: Aditi Ahuja Date: Thu, 5 Dec 2024 16:05:11 +0530 Subject: [PATCH 02/27] use ctx in term srch --- search/searcher/search_disjunction.go | 2 +- search/searcher/search_term.go | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/search/searcher/search_disjunction.go b/search/searcher/search_disjunction.go index d165ec027..434c705e7 100644 --- a/search/searcher/search_disjunction.go +++ b/search/searcher/search_disjunction.go @@ -114,7 +114,7 @@ func optimizeCompositeSearcher(ctx context.Context, optimizationKind string, return nil, nil } - return newTermSearcherFromReader(indexReader, tfr, + return newTermSearcherFromReader(ctx, indexReader, tfr, []byte(optimizationKind), "*", 1.0, options) } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 622a12c3d..164ea31ad 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -16,6 +16,7 @@ package searcher import ( "context" + "fmt" "reflect" "github.com/blevesearch/bleve/v2/search" @@ -60,17 +61,23 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, te if err != nil { return nil, err } - return newTermSearcherFromReader(indexReader, reader, term, field, boost, options) + return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options) } -func newTermSearcherFromReader(indexReader index.IndexReader, reader index.TermFieldReader, +func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { - // TODO Instead of passing count from reader here, do it using the presearch phase stats. - count, err := indexReader.DocCount() - if err != nil { - _ = reader.Close() - return nil, err + count, ok := ctx.Value(search.BM25PreSearchDataKey).(uint64) + if !ok { + var err error + count, err = indexReader.DocCount() + if err != nil { + _ = reader.Close() + return nil, err + } + } else { + fmt.Printf("fetched from ctx \n") } + scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), options) return &TermSearcher{ indexReader: indexReader, From 2d8a43d9fa5f43c77d1626802331fd0516ef0587 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Fri, 6 Dec 2024 11:25:28 +0530 Subject: [PATCH 03/27] field cardinality temp save --- index/scorch/snapshot_index.go | 14 ++++++++++---- index/scorch/snapshot_index_dict.go | 10 ++++++---- index_impl.go | 2 +- pre_search.go | 13 ++++++++++--- search.go | 5 +++-- search/searcher/search_term.go | 24 ++++++++++++++---------- 6 files changed, 44 insertions(+), 24 deletions(-) diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go index 6d0a0b60e..488465506 100644 --- a/index/scorch/snapshot_index.go +++ b/index/scorch/snapshot_index.go @@ -42,8 +42,9 @@ type asynchSegmentResult struct { dict segment.TermDictionary dictItr segment.DictionaryIterator - index int - docs *roaring.Bitmap + cardinality int + index int + docs *roaring.Bitmap thesItr segment.ThesaurusIterator @@ -137,6 +138,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, results := make(chan *asynchSegmentResult) var totalBytesRead uint64 + var fieldCardinality int64 for _, s := range is.segment { go func(s *SegmentSnapshot) { dict, err := s.segment.Dictionary(field) @@ -146,6 +148,8 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, if dictStats, ok := dict.(segment.DiskStatsReporter); ok { atomic.AddUint64(&totalBytesRead, dictStats.BytesRead()) } + + atomic.AddInt64(&fieldCardinality, int64(dict.Cardinality())) if randomLookup { results <- &asynchSegmentResult{dict: dict} } else { @@ -157,9 +161,11 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, var err error rv := &IndexSnapshotFieldDict{ - snapshot: is, - cursors: make([]*segmentDictCursor, 0, len(is.segment)), + snapshot: is, + cursors: make([]*segmentDictCursor, 0, len(is.segment)), + cardinality: int(fieldCardinality), } + for count := 0; count < len(is.segment); count++ { asr := <-results if asr.err != nil && err == nil { diff --git a/index/scorch/snapshot_index_dict.go b/index/scorch/snapshot_index_dict.go index 658aa8148..f28d5860b 100644 --- a/index/scorch/snapshot_index_dict.go +++ b/index/scorch/snapshot_index_dict.go @@ -28,10 +28,12 @@ type segmentDictCursor struct { } type IndexSnapshotFieldDict struct { - snapshot *IndexSnapshot - cursors []*segmentDictCursor - entry index.DictEntry - bytesRead uint64 + cardinality int + bytesRead uint64 + + snapshot *IndexSnapshot + cursors []*segmentDictCursor + entry index.DictEntry } func (i *IndexSnapshotFieldDict) BytesRead() uint64 { diff --git a/index_impl.go b/index_impl.go index ac32d6dbc..40ab9a737 100644 --- a/index_impl.go +++ b/index_impl.go @@ -511,7 +511,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in }, Hits: knnHits, SynonymResult: fts, - totalDocCount: count, + docCount: count, }, nil } diff --git a/pre_search.go b/pre_search.go index 13f1b496a..6e81a02cf 100644 --- a/pre_search.go +++ b/pre_search.go @@ -83,16 +83,23 @@ func (s *synonymPreSearchResultProcessor) finalize(sr *SearchResult) { } type bm25PreSearchResultProcessor struct { - docCount uint64 // bm25 specific stats + docCount uint64 // bm25 specific stats + fieldCardinality map[string]uint64 } func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { - return &bm25PreSearchResultProcessor{} + return &bm25PreSearchResultProcessor{ + fieldCardinality: make(map[string]uint64), + } } // TODO How will this work for queries other than term queries? func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { - b.docCount += (sr.totalDocCount) + b.docCount += (sr.docCount) + + for field, cardinality := range sr.fieldCardinality { + b.fieldCardinality[field] += cardinality + } } func (b *bm25PreSearchResultProcessor) finalize(sr *SearchResult) { diff --git a/search.go b/search.go index 095d58ab2..056cf3b9c 100644 --- a/search.go +++ b/search.go @@ -447,9 +447,10 @@ type SearchResult struct { // special fields that are applicable only for search // results that are obtained from a presearch SynonymResult search.FieldTermSynonymMap `json:"synonym_result,omitempty"` + // The following fields are applicable to BM25 preSearch - // todo add more fields beyond docCount - totalDocCount uint64 + docCount uint64 + fieldCardinality map[string]uint64 // search_field -> cardinality } func (sr *SearchResult) Size() int { diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 164ea31ad..636d9df21 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -66,18 +66,22 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, te func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { - count, ok := ctx.Value(search.BM25PreSearchDataKey).(uint64) - if !ok { - var err error - count, err = indexReader.DocCount() - if err != nil { - _ = reader.Close() - return nil, err + var count uint64 + if ctx != nil { + ctxCount, ok := ctx.Value(search.BM25PreSearchDataKey).(uint64) + if !ok { + var err error + ctxCount, err = indexReader.DocCount() + if err != nil { + _ = reader.Close() + return nil, err + } + } else { + fmt.Printf("fetched from ctx \n") } - } else { - fmt.Printf("fetched from ctx \n") - } + count = ctxCount + } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), options) return &TermSearcher{ indexReader: indexReader, From 52b176896b0e6aa65e1c26b7d05380147025a98c Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Fri, 6 Dec 2024 16:46:17 +0530 Subject: [PATCH 04/27] average doc length stat for a field --- index_alias_impl.go | 9 ++++++++- index_impl.go | 17 +++++++++++++++++ pre_search.go | 7 ++++--- search.go | 2 +- search/searcher/search_term.go | 21 +++++++++++++++++---- 5 files changed, 47 insertions(+), 9 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 805ccbf0b..6bfd4e27f 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -708,7 +708,10 @@ func constructSynonymPreSearchData(rv map[string]map[string]interface{}, sr *Sea } func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *SearchResult, indexes []Index) map[string]map[string]interface{} { for _, index := range indexes { - rv[index.Name()][search.BM25PreSearchDataKey] = sr.totalDocCount + rv[index.Name()][search.BM25PreSearchDataKey] = map[string]interface{}{ + "docCount": sr.docCount, + "fieldCardinality": sr.fieldCardinality, + } } return rv } @@ -1035,3 +1038,7 @@ func (f *indexAliasImplFieldDict) Close() error { defer f.index.mutex.RUnlock() return f.fieldDict.Close() } + +func (f *indexAliasImplFieldDict) Cardinality() int { + return f.fieldDict.Cardinality() +} diff --git a/index_impl.go b/index_impl.go index 40ab9a737..4f48c3d99 100644 --- a/index_impl.go +++ b/index_impl.go @@ -486,6 +486,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in var fts search.FieldTermSynonymMap var count uint64 + fieldCardinality := make(map[string]int) if !isMatchNoneQuery(req.Query) { if synMap, ok := i.m.(mapping.SynonymMapping); ok { if synReader, ok := reader.(index.ThesaurusReader); ok { @@ -501,6 +502,17 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in if err != nil { return nil, err } + + fs := make(query.FieldSet) + fs = query.ExtractFields(req.Query, i.m, fs) + + for field := range fs { + dict, err := reader.FieldDict(field) + if err != nil { + return nil, err + } + fieldCardinality[field] = dict.Cardinality() + } } } @@ -512,6 +524,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in Hits: knnHits, SynonymResult: fts, docCount: count, + fieldCardinality: fieldCardinality, }, nil } @@ -1127,6 +1140,10 @@ func (f *indexImplFieldDict) Close() error { return f.indexReader.Close() } +func (f *indexImplFieldDict) Cardinality() int { + return f.fieldDict.Cardinality() +} + // helper function to remove duplicate entries from slice of strings func deDuplicate(fields []string) []string { entries := make(map[string]struct{}) diff --git a/pre_search.go b/pre_search.go index 6e81a02cf..ebe80b723 100644 --- a/pre_search.go +++ b/pre_search.go @@ -84,12 +84,12 @@ func (s *synonymPreSearchResultProcessor) finalize(sr *SearchResult) { type bm25PreSearchResultProcessor struct { docCount uint64 // bm25 specific stats - fieldCardinality map[string]uint64 + fieldCardinality map[string]int } func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { return &bm25PreSearchResultProcessor{ - fieldCardinality: make(map[string]uint64), + fieldCardinality: make(map[string]int), } } @@ -103,7 +103,8 @@ func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { } func (b *bm25PreSearchResultProcessor) finalize(sr *SearchResult) { - + sr.docCount = b.docCount + sr.fieldCardinality = b.fieldCardinality } // ----------------------------------------------------------------------------- diff --git a/search.go b/search.go index 056cf3b9c..cebeddfd0 100644 --- a/search.go +++ b/search.go @@ -450,7 +450,7 @@ type SearchResult struct { // The following fields are applicable to BM25 preSearch docCount uint64 - fieldCardinality map[string]uint64 // search_field -> cardinality + fieldCardinality map[string]int // search_field -> cardinality } func (sr *SearchResult) Size() int { diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 636d9df21..80635f5c2 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -67,20 +67,33 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, te func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { var count uint64 + var fieldCardinality int if ctx != nil { - ctxCount, ok := ctx.Value(search.BM25PreSearchDataKey).(uint64) + bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{}) if !ok { var err error - ctxCount, err = indexReader.DocCount() + count, err = indexReader.DocCount() if err != nil { _ = reader.Close() return nil, err } + dict, err := indexReader.FieldDict(field) + if err != nil { + _ = indexReader.Close() + return nil, err + } + fieldCardinality = dict.Cardinality() } else { fmt.Printf("fetched from ctx \n") - } - count = ctxCount + count = bm25Stats["docCount"].(uint64) + fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) + fieldCardinality, ok = fieldCardinalityMap[field] + if !ok { + return nil, fmt.Errorf("field stat for bm25 not present %s", field) + } + fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) + } } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), options) return &TermSearcher{ From 42082f8c9c1af80c0436b32db14c5ae79af05ed5 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Fri, 6 Dec 2024 23:52:13 +0530 Subject: [PATCH 05/27] bm25 scoring first implementation --- index/scorch/snapshot_index.go | 8 ++++---- index/scorch/snapshot_index_dict.go | 6 ++++++ index/upsidedown/field_dict.go | 4 ++++ search/scorer/scorer_term.go | 17 ++++++++++++++--- search/searcher/search_term.go | 3 ++- 5 files changed, 30 insertions(+), 8 deletions(-) diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go index 488465506..a04760f9d 100644 --- a/index/scorch/snapshot_index.go +++ b/index/scorch/snapshot_index.go @@ -148,7 +148,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, if dictStats, ok := dict.(segment.DiskStatsReporter); ok { atomic.AddUint64(&totalBytesRead, dictStats.BytesRead()) } - + fmt.Println("bro what", int64(dict.Cardinality())) atomic.AddInt64(&fieldCardinality, int64(dict.Cardinality())) if randomLookup { results <- &asynchSegmentResult{dict: dict} @@ -161,9 +161,8 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, var err error rv := &IndexSnapshotFieldDict{ - snapshot: is, - cursors: make([]*segmentDictCursor, 0, len(is.segment)), - cardinality: int(fieldCardinality), + snapshot: is, + cursors: make([]*segmentDictCursor, 0, len(is.segment)), } for count := 0; count < len(is.segment); count++ { @@ -189,6 +188,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, } } } + rv.cardinality = int(fieldCardinality) rv.bytesRead = totalBytesRead // after ensuring we've read all items on channel if err != nil { diff --git a/index/scorch/snapshot_index_dict.go b/index/scorch/snapshot_index_dict.go index f28d5860b..729ca77e8 100644 --- a/index/scorch/snapshot_index_dict.go +++ b/index/scorch/snapshot_index_dict.go @@ -16,6 +16,7 @@ package scorch import ( "container/heap" + "fmt" index "github.com/blevesearch/bleve_index_api" segment "github.com/blevesearch/scorch_segment_api/v2" @@ -96,6 +97,11 @@ func (i *IndexSnapshotFieldDict) Next() (*index.DictEntry, error) { return &i.entry, nil } +func (i *IndexSnapshotFieldDict) Cardinality() int { + fmt.Println("cardianlity", i.cardinality) + return i.cardinality +} + func (i *IndexSnapshotFieldDict) Close() error { return nil } diff --git a/index/upsidedown/field_dict.go b/index/upsidedown/field_dict.go index 4875680c9..c990fd47b 100644 --- a/index/upsidedown/field_dict.go +++ b/index/upsidedown/field_dict.go @@ -77,6 +77,10 @@ func (r *UpsideDownCouchFieldDict) Next() (*index.DictEntry, error) { } +func (r *UpsideDownCouchFieldDict) Cardinality() int { + return 0 +} + func (r *UpsideDownCouchFieldDict) Close() error { return r.iterator.Close() } diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index 7b60eda4e..b837f607b 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -37,6 +37,7 @@ type TermQueryScorer struct { queryBoost float64 docTerm uint64 docTotal uint64 + avgDocLength float64 idf float64 options search.SearcherOptions idfExplanation *search.Explanation @@ -61,14 +62,18 @@ func (s *TermQueryScorer) Size() int { return sizeInBytes } -func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, docTerm uint64, options search.SearcherOptions) *TermQueryScorer { +func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, + docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer { + + idfVal := math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/(float64(docTerm)+0.5)) rv := TermQueryScorer{ queryTerm: string(queryTerm), queryField: queryField, queryBoost: queryBoost, docTerm: docTerm, docTotal: docTotal, - idf: 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)), + avgDocLength: avgDocLength, + idf: idfVal, options: options, queryWeight: 1.0, includeScore: options.Score != "none", @@ -125,7 +130,13 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term } else { tf = math.Sqrt(float64(termMatch.Freq)) } - score := tf * termMatch.Norm * s.idf + // score := tf * termMatch.Norm * s.idf + + // using the posting's norm value to recompute the field length for the doc num + fieldLength := 1 / (termMatch.Norm * termMatch.Norm) + var k float64 = 1 + var b float64 = 1 + score := s.idf * (tf * k) / (tf + k*(1-b+(b*fieldLength/s.avgDocLength))) if s.options.Explain { childrenExplanations := make([]*search.Explanation, 3) diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 80635f5c2..f5797030b 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -83,6 +83,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade return nil, err } fieldCardinality = dict.Cardinality() + fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) } else { fmt.Printf("fetched from ctx \n") count = bm25Stats["docCount"].(uint64) @@ -95,7 +96,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) } } - scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), options) + scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options) return &TermSearcher{ indexReader: indexReader, reader: reader, From a52bd4983fc81175d14f04e37af35b8e32631510 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Sat, 7 Dec 2024 00:13:17 +0530 Subject: [PATCH 06/27] notes and keep the default tf-idf stuff --- search/scorer/scorer_term.go | 27 ++++++++++++++++++++------- search/searcher/search_term.go | 4 ++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index b837f607b..7967ac393 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -65,7 +65,14 @@ func (s *TermQueryScorer) Size() int { func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer { - idfVal := math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/(float64(docTerm)+0.5)) + var idfVal float64 + if avgDocLength > 0 { + // avgDocLength is set only for bm25 scoring + idfVal = math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/(float64(docTerm)+0.5)) + } else { + idfVal = 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)) + } + rv := TermQueryScorer{ queryTerm: string(queryTerm), queryField: queryField, @@ -130,14 +137,20 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term } else { tf = math.Sqrt(float64(termMatch.Freq)) } - // score := tf * termMatch.Norm * s.idf - // using the posting's norm value to recompute the field length for the doc num - fieldLength := 1 / (termMatch.Norm * termMatch.Norm) - var k float64 = 1 - var b float64 = 1 - score := s.idf * (tf * k) / (tf + k*(1-b+(b*fieldLength/s.avgDocLength))) + // tf-idf scoring by default + score := tf * termMatch.Norm * s.idf + if s.avgDocLength > 0 { + // using the posting's norm value to recompute the field length for the doc num + fieldLength := 1 / (termMatch.Norm * termMatch.Norm) + + // multipliers. todo: these are something to be set in the scorer by parent layer + var k float64 = 1 + var b float64 = 1 + score = s.idf * (tf * k) / (tf + k*(1-b+(b*fieldLength/s.avgDocLength))) + } + // todo: explain stuff properly if s.options.Explain { childrenExplanations := make([]*search.Explanation, 3) childrenExplanations[0] = &search.Explanation{ diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index f5797030b..53c04357e 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -95,6 +95,10 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) } + + // in case of bm25 need to fetch the multipliers as well (maybe something set in index mapping?) + // fieldMapping := m.FieldMappingForPath(q.VectorField) + // but tbd how to pass on the field mapping here, can we pass it (the multipliers) in the context? } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options) return &TermSearcher{ From 36159b690acc182c1770c412cc706dcedef1a486 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Tue, 10 Dec 2024 11:58:11 +0530 Subject: [PATCH 07/27] bug fixes and BM25 UT pass --- index_alias_impl.go | 10 +-- index_impl.go | 16 ++-- index_test.go | 143 ++++++++++++++++++++++++++++++++- mapping/index.go | 2 + mapping/mapping.go | 2 + pre_search.go | 4 +- search/searcher/search_term.go | 4 +- 7 files changed, 164 insertions(+), 17 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 6bfd4e27f..6f9606705 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -245,6 +245,8 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest if err != nil { return nil, err } + + fmt.Println("presearch result", preSearchResult.docCount) // check if the preSearch result has any errors and if so // return the search result as is without executing the query // so that the errors are not lost @@ -601,12 +603,8 @@ func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFl } } var bm25 bool - if !isMatchNoneQuery(req.Query) { - // todo fix this cuRRENTLY ALL INDEX mappings are BM25 mappings, need to fix - // this is just a placeholder. - if _, ok := m.(mapping.BM25Mapping); ok { - bm25 = true - } + if _, ok := m.(mapping.BM25Mapping); ok { + bm25 = true } if knn || synonyms || bm25 { diff --git a/index_impl.go b/index_impl.go index 4f48c3d99..981e8fb88 100644 --- a/index_impl.go +++ b/index_impl.go @@ -483,7 +483,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in return nil, err } } - + var fts search.FieldTermSynonymMap var count uint64 fieldCardinality := make(map[string]int) @@ -521,9 +521,9 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in Total: 1, Successful: 1, }, - Hits: knnHits, - SynonymResult: fts, - docCount: count, + Hits: knnHits, + SynonymResult: fts, + docCount: count, fieldCardinality: fieldCardinality, }, nil } @@ -581,6 +581,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr var skipSynonymCollector bool var bm25TotalDocs uint64 + var bm25Data map[string]interface{} var ok bool if req.PreSearchData != nil { for k, v := range req.PreSearchData { @@ -604,7 +605,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr skipKnnCollector = true case search.BM25PreSearchDataKey: if v != nil { - bm25TotalDocs, ok = v.(uint64) + bm25Data, ok = v.(map[string]interface{}) if !ok { return nil, fmt.Errorf("bm25 preSearchData must be of type uint64") } @@ -634,8 +635,9 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if fts != nil { ctx = context.WithValue(ctx, search.FieldTermSynonymMapKey, fts) - if bm25TotalDocs > 0 { - ctx = context.WithValue(ctx, search.BM25MapKey, bm25TotalDocs) + } + if bm25Data != nil { + ctx = context.WithValue(ctx, search.BM25PreSearchDataKey, bm25Data) } // This callback and variable handles the tracking of bytes read diff --git a/index_test.go b/index_test.go index 82be0d947..3d80ed2da 100644 --- a/index_test.go +++ b/index_test.go @@ -350,6 +350,138 @@ func TestBytesWritten(t *testing.T) { cleanupTmpIndexPath(t, tmpIndexPath4) } +func TestBM25(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + indexMapping := NewIndexMapping() + indexMapping.TypeField = "type" + indexMapping.DefaultAnalyzer = "en" + documentMapping := NewDocumentMapping() + indexMapping.AddDocumentMapping("hotel", documentMapping) + indexMapping.StoreDynamic = false + indexMapping.DocValuesDynamic = false + contentFieldMapping := NewTextFieldMapping() + contentFieldMapping.Store = false + + reviewsMapping := NewDocumentMapping() + reviewsMapping.AddFieldMappingsAt("content", contentFieldMapping) + documentMapping.AddSubDocumentMapping("reviews", reviewsMapping) + + typeFieldMapping := NewTextFieldMapping() + typeFieldMapping.Store = false + documentMapping.AddFieldMappingsAt("type", typeFieldMapping) + + idxSinglePartition, err := NewUsing(tmpIndexPath, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := idxSinglePartition.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch, err := getBatchFromData(idxSinglePartition, "sample-data.json") + if err != nil { + t.Fatalf("failed to form a batch") + } + err = idxSinglePartition.Batch(batch) + if err != nil { + t.Fatalf("failed to index batch %v\n", err) + } + query := NewMatchQuery("Apartments") + query.FieldVal = "name" + searchRequest := NewSearchRequestOptions(query, int(10), 0, true) + + res, err := idxSinglePartition.Search(searchRequest) + if err != nil { + t.Error(err) + } + + fmt.Println("length of hits", res.Hits[0].Score) + dataset, _ := readDataFromFile("sample-data.json") + fmt.Println("length of dataset", len(dataset)) + tmpIndexPath1 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath1) + + idxPart1, err := NewUsing(tmpIndexPath1, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := idxPart1.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch1 := idxPart1.NewBatch() + for _, doc := range dataset[:len(dataset)/2] { + err = batch1.Index(fmt.Sprintf("%d", doc["id"]), doc) + if err != nil { + t.Fatal(err) + } + } + err = idxPart1.Batch(batch1) + if err != nil { + t.Fatal(err) + } + + tmpIndexPath2 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath2) + + idxPart2, err := NewUsing(tmpIndexPath2, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := idxPart2.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch2 := idxPart2.NewBatch() + for _, doc := range dataset[len(dataset)/2:] { + err = batch2.Index(fmt.Sprintf("%d", doc["id"]), doc) + if err != nil { + t.Fatal(err) + } + } + err = idxPart2.Batch(batch2) + if err != nil { + t.Fatal(err) + } + + multiPartIndex := NewIndexAlias(idxPart1, idxPart2) + err = multiPartIndex.SetIndexMapping(indexMapping) + if err != nil { + t.Fatal(err) + } + + res, err = multiPartIndex.Search(searchRequest) + if err != nil { + t.Error(err) + } + + // ctx := context.Background() + // ctx = context.WithValue(ctx, search.PreSearchKey, + // search.SearcherStartCallbackFn(bleveCtxSearcherStartCallback)) + + // res, err = multiPartIndex.SearchInContext(ctx, searchRequest) + // if err != nil { + // t.Error(err) + // } + + fmt.Println("length of hits alias search", res.Hits[0].Score) + +} + func TestBytesRead(t *testing.T) { tmpIndexPath := createTmpIndexPath(t) defer cleanupTmpIndexPath(t, tmpIndexPath) @@ -671,23 +803,30 @@ func TestBytesReadStored(t *testing.T) { } } -func getBatchFromData(idx Index, fileName string) (*Batch, error) { +func readDataFromFile(fileName string) ([]map[string]interface{}, error) { pwd, err := os.Getwd() if err != nil { return nil, err } path := filepath.Join(pwd, "data", "test", fileName) - batch := idx.NewBatch() + var dataset []map[string]interface{} fileContent, err := os.ReadFile(path) if err != nil { return nil, err } + err = json.Unmarshal(fileContent, &dataset) if err != nil { return nil, err } + return dataset, nil +} + +func getBatchFromData(idx Index, fileName string) (*Batch, error) { + dataset, err := readDataFromFile(fileName) + batch := idx.NewBatch() for _, doc := range dataset { err = batch.Index(fmt.Sprintf("%d", doc["id"]), doc) if err != nil { diff --git a/mapping/index.go b/mapping/index.go index 8a0d5e34a..a78a44a23 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -557,4 +557,6 @@ func (im *IndexMappingImpl) SynonymSourceVisitor(visitor analysis.SynonymSourceV return err } return nil +func (im *IndexMappingImpl) BM25Impl() { + fmt.Println("BM25Impl") } diff --git a/mapping/mapping.go b/mapping/mapping.go index 0992f9837..9b7690c19 100644 --- a/mapping/mapping.go +++ b/mapping/mapping.go @@ -76,4 +76,6 @@ type SynonymMapping interface { } type BM25Mapping interface { IndexMapping + + BM25Impl() } diff --git a/pre_search.go b/pre_search.go index ebe80b723..2b2ef03d1 100644 --- a/pre_search.go +++ b/pre_search.go @@ -15,6 +15,8 @@ package bleve import ( + "fmt" + "github.com/blevesearch/bleve/v2/search" ) @@ -96,7 +98,7 @@ func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { // TODO How will this work for queries other than term queries? func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { b.docCount += (sr.docCount) - + fmt.Println("docCount: ", b.docCount) for field, cardinality := range sr.fieldCardinality { b.fieldCardinality[field] += cardinality } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 53c04357e..e0c817859 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -83,7 +83,9 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade return nil, err } fieldCardinality = dict.Cardinality() - fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) + fmt.Println("------------------") + fmt.Println("the num docs", count) + fmt.Println("the field cardinality", fieldCardinality) } else { fmt.Printf("fetched from ctx \n") count = bm25Stats["docCount"].(uint64) From f3424b5a45c215b0637494c540c88641f91b26a8 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Tue, 10 Dec 2024 12:43:21 +0530 Subject: [PATCH 08/27] making bm25 presearch (i.e. global scoring) optional --- index_alias_impl.go | 21 +++++++++----- index_impl.go | 4 ++- index_test.go | 15 ++++------ search/scorer/scorer_term.go | 51 +++++++++++++++++++++------------- search/searcher/search_term.go | 10 +++---- search/util.go | 3 ++ 6 files changed, 62 insertions(+), 42 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 6f9606705..e122cdf48 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -16,7 +16,6 @@ package bleve import ( "context" - "fmt" "sync" "time" @@ -246,7 +245,6 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest return nil, err } - fmt.Println("presearch result", preSearchResult.docCount) // check if the preSearch result has any errors and if so // return the search result as is without executing the query // so that the errors are not lost @@ -579,9 +577,9 @@ type preSearchFlags struct { bm25 bool // needs presearch for this too } -// preSearchRequired checks if preSearch is required and returns a boolean flag -// It only allocates the preSearchFlags struct if necessary -func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error) { +// preSearchRequired checks if preSearch is required and returns the presearch flags struct +// indicating which preSearch is required +func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error){ // Check for KNN query knn := requestHasKNN(req) var synonyms bool @@ -603,8 +601,17 @@ func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFl } } var bm25 bool - if _, ok := m.(mapping.BM25Mapping); ok { - bm25 = true + if !isMatchNoneQuery(req.Query) { + if ctx != nil { + if searchType := ctx.Value(search.SearchTypeKey); searchType != nil { + if searchType.(string) == search.FetchStatsAndSearch { + // todo: check mapping to see if bm25 is needed + if _, ok := m.(mapping.BM25Mapping); ok { + bm25 = true + } + } + } + } } if knn || synonyms || bm25 { diff --git a/index_impl.go b/index_impl.go index 981e8fb88..a7751bc4f 100644 --- a/index_impl.go +++ b/index_impl.go @@ -607,7 +607,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if v != nil { bm25Data, ok = v.(map[string]interface{}) if !ok { - return nil, fmt.Errorf("bm25 preSearchData must be of type uint64") + return nil, fmt.Errorf("bm25 preSearchData must be of type map[string]interface{}") } } } @@ -636,6 +636,8 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if fts != nil { ctx = context.WithValue(ctx, search.FieldTermSynonymMapKey, fts) } + // set the bm25 presearch data (stats important for consistent scoring) in + // the context object if bm25Data != nil { ctx = context.WithValue(ctx, search.BM25PreSearchDataKey, bm25Data) } diff --git a/index_test.go b/index_test.go index 3d80ed2da..05675cbd0 100644 --- a/index_test.go +++ b/index_test.go @@ -464,20 +464,15 @@ func TestBM25(t *testing.T) { t.Fatal(err) } - res, err = multiPartIndex.Search(searchRequest) + ctx := context.Background() + // not setting this doesn't perform a presearch for bm25 + ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) + + res, err = multiPartIndex.SearchInContext(ctx, searchRequest) if err != nil { t.Error(err) } - // ctx := context.Background() - // ctx = context.WithValue(ctx, search.PreSearchKey, - // search.SearcherStartCallbackFn(bleveCtxSearcherStartCallback)) - - // res, err = multiPartIndex.SearchInContext(ctx, searchRequest) - // if err != nil { - // t.Error(err) - // } - fmt.Println("length of hits alias search", res.Hits[0].Score) } diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index 7967ac393..f98d4288d 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -62,17 +62,23 @@ func (s *TermQueryScorer) Size() int { return sizeInBytes } -func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, - docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer { - - var idfVal float64 +func (s *TermQueryScorer) computeIDF(avgDocLength float64, docTotal, docTerm uint64) float64 { + var rv float64 if avgDocLength > 0 { // avgDocLength is set only for bm25 scoring - idfVal = math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/(float64(docTerm)+0.5)) + rv = math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/ + (float64(docTerm)+0.5)) } else { - idfVal = 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)) + rv = 1.0 + math.Log(float64(docTotal)/ + float64(docTerm+1.0)) } + return rv +} + +func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, + docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer { + rv := TermQueryScorer{ queryTerm: string(queryTerm), queryField: queryField, @@ -80,12 +86,12 @@ func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTerm: docTerm, docTotal: docTotal, avgDocLength: avgDocLength, - idf: idfVal, options: options, queryWeight: 1.0, includeScore: options.Score != "none", } + rv.idf = rv.computeIDF(avgDocLength, docTotal, docTerm) if options.Explain { rv.idfExplanation = &search.Explanation{ Value: rv.idf, @@ -126,6 +132,24 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) { } } +func (s *TermQueryScorer) docScore(tf, norm float64) float64 { + // tf-idf scoring by default + score := tf * norm * s.idf + if s.avgDocLength > 0 { + // bm25 scoring + // using the posting's norm value to recompute the field length for the doc num + fieldLength := 1 / (norm * norm) + + // multiplies deciding how much does a doc length affect the score and also + // how much can the term frequency affect the score + var k1 float64 = 1 + var b float64 = 1 + score = s.idf * (tf * k1) / + (tf + k1*(1-b+(b*fieldLength/s.avgDocLength))) + } + return score +} + func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.TermFieldDoc) *search.DocumentMatch { rv := ctx.DocumentMatchPool.Get() // perform any score computations only when needed @@ -138,18 +162,7 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term tf = math.Sqrt(float64(termMatch.Freq)) } - // tf-idf scoring by default - score := tf * termMatch.Norm * s.idf - if s.avgDocLength > 0 { - // using the posting's norm value to recompute the field length for the doc num - fieldLength := 1 / (termMatch.Norm * termMatch.Norm) - - // multipliers. todo: these are something to be set in the scorer by parent layer - var k float64 = 1 - var b float64 = 1 - score = s.idf * (tf * k) / (tf + k*(1-b+(b*fieldLength/s.avgDocLength))) - } - + score := s.docScore(tf, termMatch.Norm) // todo: explain stuff properly if s.options.Explain { childrenExplanations := make([]*search.Explanation, 3) diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index e0c817859..07908600c 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -39,14 +39,16 @@ type TermSearcher struct { tfd index.TermFieldDoc } -func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { +func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, + term string, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { if isTermQuery(ctx) { ctx = context.WithValue(ctx, search.QueryTypeKey, search.Term) } return NewTermSearcherBytes(ctx, indexReader, []byte(term), field, boost, options) } -func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, term []byte, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { +func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, + term []byte, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { if ctx != nil { if fts, ok := ctx.Value(search.FieldTermSynonymMapKey).(search.FieldTermSynonymMap); ok { if ts, exists := fts[field]; exists { @@ -98,9 +100,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) } - // in case of bm25 need to fetch the multipliers as well (maybe something set in index mapping?) - // fieldMapping := m.FieldMappingForPath(q.VectorField) - // but tbd how to pass on the field mapping here, can we pass it (the multipliers) in the context? + // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options) return &TermSearcher{ diff --git a/search/util.go b/search/util.go index 0aa8c7438..14534027b 100644 --- a/search/util.go +++ b/search/util.go @@ -141,6 +141,9 @@ const BM25PreSearchDataKey = "_bm25_pre_search_data_key" const PreSearchKey = "_presearch_key" +const SearchTypeKey = "_search_type_key" +const FetchStatsAndSearch = "fetch_stats_and_search" + type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) type SearcherStartCallbackFn func(size uint64) error From d393616e041e54928da91512c75c937fa5473a14 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Wed, 11 Dec 2024 12:15:38 +0530 Subject: [PATCH 09/27] field mapping to capture type of scoring; bm25 by default --- index_alias_impl.go | 37 +++++++------- index_impl.go | 10 +++- index_test.go | 2 +- mapping/field.go | 4 +- mapping/mapping_vectors.go | 6 +-- search/query/knn.go | 2 +- search/searcher/search_term.go | 90 ++++++++++++++++++++++++---------- search/util.go | 10 ++-- 8 files changed, 106 insertions(+), 55 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index e122cdf48..dbc7ac85d 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -723,26 +723,27 @@ func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *Search func constructPreSearchData(req *SearchRequest, flags *preSearchFlags, preSearchResult *SearchResult, indexes []Index) (map[string]map[string]interface{}, error) { - if flags == nil || preSearchResult == nil { - return nil, fmt.Errorf("invalid input, flags: %v, preSearchResult: %v", flags, preSearchResult) - } - mergedOut := make(map[string]map[string]interface{}, len(indexes)) - for _, index := range indexes { - mergedOut[index.Name()] = make(map[string]interface{}) - } - var err error - if flags.knn { - mergedOut, err = constructKnnPreSearchData(mergedOut, preSearchResult, indexes) - if err != nil { - return nil, err + if flags == nil || preSearchResult == nil { + return nil, fmt.Errorf("invalid input, flags: %v, preSearchResult: %v", flags, preSearchResult) } + mergedOut := make(map[string]map[string]interface{}, len(indexes)) + for _, index := range indexes { + mergedOut[index.Name()] = make(map[string]interface{}) + } + var err error + if flags.knn { + mergedOut, err = constructKnnPreSearchData(mergedOut, preSearchResult, indexes) + if err != nil { + return nil, err + } + } + if flags.synonyms { + mergedOut = constructSynonymPreSearchData(mergedOut, preSearchResult, indexes) + if flags.bm25 { + mergedOut = constructBM25PreSearchData(mergedOut, preSearchResult, indexes) + } + return mergedOut, nil } - if flags.synonyms { - mergedOut = constructSynonymPreSearchData(mergedOut, preSearchResult, indexes) - if flags.bm25 { - mergedOut = constructBM25PreSearchData(mergedOut, preSearchResult, indexes) - } - return mergedOut, nil } func preSearchDataSearch(ctx context.Context, req *SearchRequest, flags *preSearchFlags, indexes ...Index) (*SearchResult, error) { diff --git a/index_impl.go b/index_impl.go index a7751bc4f..be8623bbf 100644 --- a/index_impl.go +++ b/index_impl.go @@ -580,7 +580,6 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr var fts search.FieldTermSynonymMap var skipSynonymCollector bool - var bm25TotalDocs uint64 var bm25Data map[string]interface{} var ok bool if req.PreSearchData != nil { @@ -602,7 +601,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr } skipSynonymCollector = true } - skipKnnCollector = true + skipKNNCollector = true case search.BM25PreSearchDataKey: if v != nil { bm25Data, ok = v.(map[string]interface{}) @@ -636,6 +635,13 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if fts != nil { ctx = context.WithValue(ctx, search.FieldTermSynonymMapKey, fts) } + fieldMappingCallback := func(field string) string { + rv := i.m.FieldMappingForPath(field) + return rv.Similarity + } + ctx = context.WithValue(ctx, search.GetSimilarityModelCallbackKey, + search.GetSimilarityModelCallbackFn(fieldMappingCallback)) + // set the bm25 presearch data (stats important for consistent scoring) in // the context object if bm25Data != nil { diff --git a/index_test.go b/index_test.go index 05675cbd0..8a60e4006 100644 --- a/index_test.go +++ b/index_test.go @@ -466,7 +466,7 @@ func TestBM25(t *testing.T) { ctx := context.Background() // not setting this doesn't perform a presearch for bm25 - ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) + // ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) res, err = multiPartIndex.SearchInContext(ctx, searchRequest) if err != nil { diff --git a/mapping/field.go b/mapping/field.go index ce2878b18..cfb390b40 100644 --- a/mapping/field.go +++ b/mapping/field.go @@ -74,8 +74,8 @@ type FieldMapping struct { Dims int `json:"dims,omitempty"` // Similarity is the similarity algorithm used for scoring - // vector fields. - // See: index.DefaultSimilarityMetric & index.SupportedSimilarityMetrics + // field's content while performing search on it. + // See: index.SimilarityModels Similarity string `json:"similarity,omitempty"` // Applicable to vector fields only - optimization string diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go index dbfde1fb0..20cbac6a8 100644 --- a/mapping/mapping_vectors.go +++ b/mapping/mapping_vectors.go @@ -204,7 +204,7 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string, } if field.Similarity == "" { - field.Similarity = index.DefaultSimilarityMetric + field.Similarity = index.DefaultVectorSimilarityMetric } if field.VectorIndexOptimizedFor == "" { @@ -249,10 +249,10 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string, MinVectorDims, MaxVectorDims) } - if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok { + if _, ok := index.SupportedVectorSimilarityMetrics[field.Similarity]; !ok { return fmt.Errorf("field: '%s', invalid similarity "+ "metric: '%s', valid metrics are: %+v", field.Name, field.Similarity, - reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys()) + reflect.ValueOf(index.SupportedVectorSimilarityMetrics).MapKeys()) } if fieldAliasCtx != nil { // writing to a nil map is unsafe diff --git a/search/query/knn.go b/search/query/knn.go index 4d105d943..8221fbcea 100644 --- a/search/query/knn.go +++ b/search/query/knn.go @@ -82,7 +82,7 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader, fieldMapping := m.FieldMappingForPath(q.VectorField) similarityMetric := fieldMapping.Similarity if similarityMetric == "" { - similarityMetric = index.DefaultSimilarityMetric + similarityMetric = index.DefaultVectorSimilarityMetric } if q.K <= 0 || len(q.Vector) == 0 { return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty") diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 07908600c..ab4309660 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -66,41 +66,81 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options) } +func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, int, error) { + // default tf-idf stats + count, err := indexReader.DocCount() + if err != nil { + return 0, 0, err + } + fieldCardinality := 0 + return count, fieldCardinality, nil +} + +func bm25ScoreMetrics(ctx context.Context, field string, + indexReader index.IndexReader) (uint64, int, error) { + var count uint64 + var fieldCardinality int + var err error + + bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{}) + if !ok { + count, err = indexReader.DocCount() + if err != nil { + return 0, 0, err + } + dict, err := indexReader.FieldDict(field) + if err != nil { + return 0, 0, err + } + fieldCardinality = dict.Cardinality() + } else { + count = bm25Stats["docCount"].(uint64) + fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) + fieldCardinality, ok = fieldCardinalityMap[field] + if !ok { + return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) + } + } + + fmt.Println("----------bm25 stats--------") + fmt.Println("docCount: ", count) + fmt.Println("fieldCardinality: ", fieldCardinality) + fmt.Println("avgDocLength: ", fieldCardinality/int(count)) + + return count, fieldCardinality, nil +} + func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { var count uint64 var fieldCardinality int + var err error if ctx != nil { - bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{}) - if !ok { - var err error - count, err = indexReader.DocCount() - if err != nil { - _ = reader.Close() - return nil, err + if similaritModelCallback, ok := ctx.Value(search. + GetSimilarityModelCallbackKey).(search.GetSimilarityModelCallbackFn); ok { + similarityModel := similaritModelCallback(field) + if similarityModel == "" || similarityModel == index.BM25Similarity { + // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) + count, fieldCardinality, err = bm25ScoreMetrics(ctx, field, indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } + } else { + count, fieldCardinality, err = tfTDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } } - dict, err := indexReader.FieldDict(field) + } else { + // default tf-idf stats + count, fieldCardinality, err = tfTDFScoreMetrics(indexReader) if err != nil { - _ = indexReader.Close() + _ = reader.Close() return nil, err } - fieldCardinality = dict.Cardinality() - fmt.Println("------------------") - fmt.Println("the num docs", count) - fmt.Println("the field cardinality", fieldCardinality) - } else { - fmt.Printf("fetched from ctx \n") - count = bm25Stats["docCount"].(uint64) - fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) - fieldCardinality, ok = fieldCardinalityMap[field] - if !ok { - return nil, fmt.Errorf("field stat for bm25 not present %s", field) - } - - fmt.Println("average doc length for", field, "is", fieldCardinality/int(count)) } - - // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options) return &TermSearcher{ diff --git a/search/util.go b/search/util.go index 14534027b..d438cd61d 100644 --- a/search/util.go +++ b/search/util.go @@ -144,7 +144,8 @@ const PreSearchKey = "_presearch_key" const SearchTypeKey = "_search_type_key" const FetchStatsAndSearch = "fetch_stats_and_search" -type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) +const SearcherStartCallbackKey = "_searcher_start_callback_key" +const SearcherEndCallbackKey = "_searcher_end_callback_key" type SearcherStartCallbackFn func(size uint64) error type SearcherEndCallbackFn func(size uint64) error @@ -168,5 +169,8 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { const FieldTermSynonymMapKey = "_field_term_synonym_map_key" const BM25MapKey = "_bm25_map_key" -const SearcherStartCallbackKey = "_searcher_start_callback_key" -const SearcherEndCallbackKey = "_searcher_end_callback_key" +const GetSimilarityModelCallbackKey = "_get_similarity_model" + +type GetSimilarityModelCallbackFn func(field string) string + +type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) From 55e63fd6158e334b71a98f0f8a9f24ad60103b3d Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Wed, 11 Dec 2024 12:51:40 +0530 Subject: [PATCH 10/27] bug fixes, unit test fixes --- index/scorch/snapshot_index.go | 1 - pre_search.go | 3 --- search/scorer/scorer_term_test.go | 4 +-- search/searcher/search_term.go | 44 ++++++++++++++++++++++--------- 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go index a04760f9d..ece32eee6 100644 --- a/index/scorch/snapshot_index.go +++ b/index/scorch/snapshot_index.go @@ -148,7 +148,6 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string, if dictStats, ok := dict.(segment.DiskStatsReporter); ok { atomic.AddUint64(&totalBytesRead, dictStats.BytesRead()) } - fmt.Println("bro what", int64(dict.Cardinality())) atomic.AddInt64(&fieldCardinality, int64(dict.Cardinality())) if randomLookup { results <- &asynchSegmentResult{dict: dict} diff --git a/pre_search.go b/pre_search.go index 2b2ef03d1..6978dd5ef 100644 --- a/pre_search.go +++ b/pre_search.go @@ -15,8 +15,6 @@ package bleve import ( - "fmt" - "github.com/blevesearch/bleve/v2/search" ) @@ -98,7 +96,6 @@ func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { // TODO How will this work for queries other than term queries? func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { b.docCount += (sr.docCount) - fmt.Println("docCount: ", b.docCount) for field, cardinality := range sr.fieldCardinality { b.fieldCardinality[field] += cardinality } diff --git a/search/scorer/scorer_term_test.go b/search/scorer/scorer_term_test.go index ffe535183..5a7522514 100644 --- a/search/scorer/scorer_term_test.go +++ b/search/scorer/scorer_term_test.go @@ -30,7 +30,7 @@ func TestTermScorer(t *testing.T) { var queryTerm = []byte("beer") var queryField = "desc" var queryBoost = 1.0 - scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, search.SearcherOptions{Explain: true}) + scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, 0, search.SearcherOptions{Explain: true}) idf := 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)) tests := []struct { @@ -175,7 +175,7 @@ func TestTermScorerWithQueryNorm(t *testing.T) { var queryTerm = []byte("beer") var queryField = "desc" var queryBoost = 3.0 - scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, search.SearcherOptions{Explain: true}) + scorer := NewTermQueryScorer(queryTerm, queryField, queryBoost, docTotal, docTerm, 0, search.SearcherOptions{Explain: true}) idf := 1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)) scorer.SetQueryNorm(2.0) diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index ab4309660..9992da9fe 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -66,18 +66,26 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options) } -func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, int, error) { +func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, float64, error) { // default tf-idf stats count, err := indexReader.DocCount() if err != nil { return 0, 0, err } fieldCardinality := 0 - return count, fieldCardinality, nil + + // fmt.Println("----------tf-idf stats--------") + // fmt.Println("docCount: ", count) + // fmt.Println("fieldCardinality: ", fieldCardinality) + + if count == 0 && fieldCardinality == 0 { + return 0, 0, nil + } + return count, float64(fieldCardinality / int(count)), nil } func bm25ScoreMetrics(ctx context.Context, field string, - indexReader index.IndexReader) (uint64, int, error) { + indexReader index.IndexReader) (uint64, float64, error) { var count uint64 var fieldCardinality int var err error @@ -102,18 +110,21 @@ func bm25ScoreMetrics(ctx context.Context, field string, } } - fmt.Println("----------bm25 stats--------") - fmt.Println("docCount: ", count) - fmt.Println("fieldCardinality: ", fieldCardinality) - fmt.Println("avgDocLength: ", fieldCardinality/int(count)) + // fmt.Println("----------bm25 stats--------") + // fmt.Println("docCount: ", count) + // fmt.Println("fieldCardinality: ", fieldCardinality) + // fmt.Println("avgDocLength: ", fieldCardinality/int(count)) - return count, fieldCardinality, nil + if count == 0 && fieldCardinality == 0 { + return 0, 0, nil + } + return count, float64(fieldCardinality / int(count)), nil } func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { var count uint64 - var fieldCardinality int + var avgDocLength float64 var err error if ctx != nil { if similaritModelCallback, ok := ctx.Value(search. @@ -121,13 +132,13 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade similarityModel := similaritModelCallback(field) if similarityModel == "" || similarityModel == index.BM25Similarity { // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) - count, fieldCardinality, err = bm25ScoreMetrics(ctx, field, indexReader) + count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) if err != nil { _ = reader.Close() return nil, err } } else { - count, fieldCardinality, err = tfTDFScoreMetrics(indexReader) + count, avgDocLength, err = tfTDFScoreMetrics(indexReader) if err != nil { _ = reader.Close() return nil, err @@ -135,14 +146,21 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade } } else { // default tf-idf stats - count, fieldCardinality, err = tfTDFScoreMetrics(indexReader) + count, avgDocLength, err = tfTDFScoreMetrics(indexReader) if err != nil { _ = reader.Close() return nil, err } } + } else { + // default tf-idf stats + count, avgDocLength, err = tfTDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } } - scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), float64(fieldCardinality/int(count)), options) + scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), avgDocLength, options) return &TermSearcher{ indexReader: indexReader, reader: reader, From 04e1e7263fb1b36c4b3d8f7c4fd46dc60fbffd83 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Thu, 12 Dec 2024 13:23:30 +0530 Subject: [PATCH 11/27] cleanup/refactor --- index/scorch/snapshot_index_dict.go | 2 -- index_test.go | 5 +-- search/scorer/scorer_term.go | 4 +-- search/searcher/search_term.go | 52 +++++++++++++---------------- 4 files changed, 28 insertions(+), 35 deletions(-) diff --git a/index/scorch/snapshot_index_dict.go b/index/scorch/snapshot_index_dict.go index 729ca77e8..2ae789c6b 100644 --- a/index/scorch/snapshot_index_dict.go +++ b/index/scorch/snapshot_index_dict.go @@ -16,7 +16,6 @@ package scorch import ( "container/heap" - "fmt" index "github.com/blevesearch/bleve_index_api" segment "github.com/blevesearch/scorch_segment_api/v2" @@ -98,7 +97,6 @@ func (i *IndexSnapshotFieldDict) Next() (*index.DictEntry, error) { } func (i *IndexSnapshotFieldDict) Cardinality() int { - fmt.Println("cardianlity", i.cardinality) return i.cardinality } diff --git a/index_test.go b/index_test.go index 8a60e4006..7fb3abf0c 100644 --- a/index_test.go +++ b/index_test.go @@ -465,8 +465,9 @@ func TestBM25(t *testing.T) { } ctx := context.Background() - // not setting this doesn't perform a presearch for bm25 - // ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) + // this key is set to ensure that we have a consistent scoring at the index alias + // level (it forces a pre search phase which can have a small overhead) + ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) res, err = multiPartIndex.SearchInContext(ctx, searchRequest) if err != nil { diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index f98d4288d..5f966e489 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -142,8 +142,8 @@ func (s *TermQueryScorer) docScore(tf, norm float64) float64 { // multiplies deciding how much does a doc length affect the score and also // how much can the term frequency affect the score - var k1 float64 = 1 - var b float64 = 1 + var k1 float64 = 1.2 + var b float64 = 0.75 score = s.idf * (tf * k1) / (tf + k1*(1-b+(b*fieldLength/s.avgDocLength))) } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 9992da9fe..240555036 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -74,9 +74,9 @@ func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, float64, error) { } fieldCardinality := 0 - // fmt.Println("----------tf-idf stats--------") - // fmt.Println("docCount: ", count) - // fmt.Println("fieldCardinality: ", fieldCardinality) + fmt.Println("----------tf-idf stats--------") + fmt.Println("docCount: ", count) + fmt.Println("fieldCardinality: ", fieldCardinality) if count == 0 && fieldCardinality == 0 { return 0, 0, nil @@ -102,18 +102,24 @@ func bm25ScoreMetrics(ctx context.Context, field string, } fieldCardinality = dict.Cardinality() } else { + fmt.Println("prefetching stats") count = bm25Stats["docCount"].(uint64) fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) fieldCardinality, ok = fieldCardinalityMap[field] if !ok { return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) } + + fmt.Println("----------bm25 stats--------") + fmt.Println("docCount: ", count) + fmt.Println("fieldCardinality: ", fieldCardinality) + fmt.Println("avgDocLength: ", fieldCardinality/int(count)) } - // fmt.Println("----------bm25 stats--------") - // fmt.Println("docCount: ", count) - // fmt.Println("fieldCardinality: ", fieldCardinality) - // fmt.Println("avgDocLength: ", fieldCardinality/int(count)) + fmt.Println("----------bm25 stats--------") + fmt.Println("docCount: ", count) + fmt.Println("fieldCardinality: ", fieldCardinality) + fmt.Println("avgDocLength: ", fieldCardinality/int(count)) if count == 0 && fieldCardinality == 0 { return 0, 0, nil @@ -121,11 +127,19 @@ func bm25ScoreMetrics(ctx context.Context, field string, return count, float64(fieldCardinality / int(count)), nil } -func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, reader index.TermFieldReader, - term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { +func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, + reader index.TermFieldReader, term []byte, field string, boost float64, + options search.SearcherOptions) (*TermSearcher, error) { var count uint64 var avgDocLength float64 var err error + + // as a fallback case we track certain stats for tf-idf scoring + count, avgDocLength, err = tfTDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } if ctx != nil { if similaritModelCallback, ok := ctx.Value(search. GetSimilarityModelCallbackKey).(search.GetSimilarityModelCallbackFn); ok { @@ -137,27 +151,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade _ = reader.Close() return nil, err } - } else { - count, avgDocLength, err = tfTDFScoreMetrics(indexReader) - if err != nil { - _ = reader.Close() - return nil, err - } } - } else { - // default tf-idf stats - count, avgDocLength, err = tfTDFScoreMetrics(indexReader) - if err != nil { - _ = reader.Close() - return nil, err - } - } - } else { - // default tf-idf stats - count, avgDocLength, err = tfTDFScoreMetrics(indexReader) - if err != nil { - _ = reader.Close() - return nil, err } } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), avgDocLength, options) From ab58975af269b85879999d7730d556bf2b61980e Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Thu, 12 Dec 2024 16:08:57 +0530 Subject: [PATCH 12/27] bug fixes --- index_alias_impl.go | 6 ++++-- index_impl.go | 9 ++++++--- pre_search.go | 8 ++++---- search.go | 4 ++-- search/searcher/search_term.go | 14 +++----------- 5 files changed, 19 insertions(+), 22 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index dbc7ac85d..32a21efd5 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -16,6 +16,7 @@ package bleve import ( "context" + "fmt" "sync" "time" @@ -714,8 +715,8 @@ func constructSynonymPreSearchData(rv map[string]map[string]interface{}, sr *Sea func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *SearchResult, indexes []Index) map[string]map[string]interface{} { for _, index := range indexes { rv[index.Name()][search.BM25PreSearchDataKey] = map[string]interface{}{ - "docCount": sr.docCount, - "fieldCardinality": sr.fieldCardinality, + "docCount": sr.DocCount, + "fieldCardinality": sr.FieldCardinality, } } return rv @@ -924,6 +925,7 @@ func MultiSearch(ctx context.Context, req *SearchRequest, preSearchData map[stri var payload map[string]interface{} if preSearchData != nil { payload = preSearchData[in.Name()] + fmt.Println("the payload", payload) } go searchChildIndex(in, createChildSearchRequest(req, payload)) } diff --git a/index_impl.go b/index_impl.go index be8623bbf..1c0648f72 100644 --- a/index_impl.go +++ b/index_impl.go @@ -504,7 +504,10 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in } fs := make(query.FieldSet) - fs = query.ExtractFields(req.Query, i.m, fs) + fs, err = query.ExtractFields(req.Query, i.m, fs) + if err != nil { + return nil, err + } for field := range fs { dict, err := reader.FieldDict(field) @@ -523,8 +526,8 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in }, Hits: knnHits, SynonymResult: fts, - docCount: count, - fieldCardinality: fieldCardinality, + DocCount: count, + FieldCardinality: fieldCardinality, }, nil } diff --git a/pre_search.go b/pre_search.go index 6978dd5ef..3872dcb61 100644 --- a/pre_search.go +++ b/pre_search.go @@ -95,15 +95,15 @@ func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { // TODO How will this work for queries other than term queries? func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { - b.docCount += (sr.docCount) - for field, cardinality := range sr.fieldCardinality { + b.docCount += (sr.DocCount) + for field, cardinality := range sr.FieldCardinality { b.fieldCardinality[field] += cardinality } } func (b *bm25PreSearchResultProcessor) finalize(sr *SearchResult) { - sr.docCount = b.docCount - sr.fieldCardinality = b.fieldCardinality + sr.DocCount = b.docCount + sr.FieldCardinality = b.fieldCardinality } // ----------------------------------------------------------------------------- diff --git a/search.go b/search.go index cebeddfd0..3e4e1a256 100644 --- a/search.go +++ b/search.go @@ -449,8 +449,8 @@ type SearchResult struct { SynonymResult search.FieldTermSynonymMap `json:"synonym_result,omitempty"` // The following fields are applicable to BM25 preSearch - docCount uint64 - fieldCardinality map[string]int // search_field -> cardinality + DocCount uint64 + FieldCardinality map[string]int // search_field -> cardinality } func (sr *SearchResult) Size() int { diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 240555036..456c10b8b 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -17,6 +17,7 @@ package searcher import ( "context" "fmt" + "math" "reflect" "github.com/blevesearch/bleve/v2/search" @@ -74,10 +75,6 @@ func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, float64, error) { } fieldCardinality := 0 - fmt.Println("----------tf-idf stats--------") - fmt.Println("docCount: ", count) - fmt.Println("fieldCardinality: ", fieldCardinality) - if count == 0 && fieldCardinality == 0 { return 0, 0, nil } @@ -109,22 +106,17 @@ func bm25ScoreMetrics(ctx context.Context, field string, if !ok { return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) } - - fmt.Println("----------bm25 stats--------") - fmt.Println("docCount: ", count) - fmt.Println("fieldCardinality: ", fieldCardinality) - fmt.Println("avgDocLength: ", fieldCardinality/int(count)) } fmt.Println("----------bm25 stats--------") fmt.Println("docCount: ", count) fmt.Println("fieldCardinality: ", fieldCardinality) - fmt.Println("avgDocLength: ", fieldCardinality/int(count)) + fmt.Println("avgDocLength: ", math.Ceil(float64(fieldCardinality)/float64(count))) if count == 0 && fieldCardinality == 0 { return 0, 0, nil } - return count, float64(fieldCardinality / int(count)), nil + return count, math.Ceil(float64(fieldCardinality) / float64(count)), nil } func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, From dbed9575a344df75dccf4c9249b4d70677b75a7d Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Fri, 13 Dec 2024 17:04:39 +0530 Subject: [PATCH 13/27] fix scatter-gather path --- index_alias_impl.go | 44 ++++++++++++++++++++++++---------- index_impl.go | 24 +++++++------------ mapping/index.go | 2 -- mapping/mapping.go | 5 ---- pre_search.go | 20 +++++++++------- search.go | 3 +-- search/query/query.go | 13 ++++++++++ search/searcher/search_term.go | 13 ++++++---- search/util.go | 5 ++++ 9 files changed, 78 insertions(+), 51 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 32a21efd5..56abb34ff 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -16,7 +16,6 @@ package bleve import ( "context" - "fmt" "sync" "time" @@ -578,6 +577,27 @@ type preSearchFlags struct { bm25 bool // needs presearch for this too } +func isBM25Enabled(req *SearchRequest, m mapping.IndexMapping) (bool, query.FieldSet) { + rv := false + fs := make(query.FieldSet) + fs, err := query.ExtractFields(req.Query, m, fs) + if err != nil { + return rv, nil + } + // if there is any field that has bm25 scoring enabled, we set + // the flag to true to presearch the stats needed for the bm25 + // scoring. Otherwise, we just skip the presearch + for field := range fs { + f := m.FieldMappingForPath(field) + if f.Similarity == "" || f.Similarity == index.BM25Similarity { + rv = true + break + } + } + + return rv, fs +} + // preSearchRequired checks if preSearch is required and returns the presearch flags struct // indicating which preSearch is required func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error){ @@ -606,10 +626,7 @@ func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexM if ctx != nil { if searchType := ctx.Value(search.SearchTypeKey); searchType != nil { if searchType.(string) == search.FetchStatsAndSearch { - // todo: check mapping to see if bm25 is needed - if _, ok := m.(mapping.BM25Mapping); ok { - bm25 = true - } + bm25, _ = isBM25Enabled(req, m) } } } @@ -713,10 +730,13 @@ func constructSynonymPreSearchData(rv map[string]map[string]interface{}, sr *Sea } } func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *SearchResult, indexes []Index) map[string]map[string]interface{} { - for _, index := range indexes { - rv[index.Name()][search.BM25PreSearchDataKey] = map[string]interface{}{ - "docCount": sr.DocCount, - "fieldCardinality": sr.FieldCardinality, + bmStats := sr.BM25Stats + if bmStats != nil { + for _, index := range indexes { + rv[index.Name()][search.BM25PreSearchDataKey] = &search.BM25Stats{ + DocCount: bmStats.DocCount, + FieldCardinality: bmStats.FieldCardinality, + } } } return rv @@ -852,10 +872,9 @@ func redistributePreSearchData(req *SearchRequest, indexes []Index) (map[string] for _, index := range indexes { rv[index.Name()][search.SynonymPreSearchDataKey] = fts - // TODO Extend to more stats - if totalDocCount, ok := req.PreSearchData[search.BM25PreSearchDataKey].(uint64); ok { + if bm25Data, ok := req.PreSearchData[search.BM25PreSearchDataKey].(*search.BM25Stats); ok { for _, index := range indexes { - rv[index.Name()][search.BM25PreSearchDataKey] = totalDocCount + rv[index.Name()][search.BM25PreSearchDataKey] = bm25Data } } return rv, nil @@ -925,7 +944,6 @@ func MultiSearch(ctx context.Context, req *SearchRequest, preSearchData map[stri var payload map[string]interface{} if preSearchData != nil { payload = preSearchData[in.Name()] - fmt.Println("the payload", payload) } go searchChildIndex(in, createChildSearchRequest(req, payload)) } diff --git a/index_impl.go b/index_impl.go index 1c0648f72..5497e1df1 100644 --- a/index_impl.go +++ b/index_impl.go @@ -38,7 +38,6 @@ import ( "github.com/blevesearch/bleve/v2/search/collector" "github.com/blevesearch/bleve/v2/search/facet" "github.com/blevesearch/bleve/v2/search/highlight" - "github.com/blevesearch/bleve/v2/search/query" "github.com/blevesearch/bleve/v2/util" index "github.com/blevesearch/bleve_index_api" "github.com/blevesearch/geo/s2" @@ -496,19 +495,12 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in } } } - - if _, ok := i.m.(mapping.BM25Mapping); ok { + if ok, fs := isBM25Enabled(req, i.m); ok { count, err = reader.DocCount() if err != nil { return nil, err } - fs := make(query.FieldSet) - fs, err = query.ExtractFields(req.Query, i.m, fs) - if err != nil { - return nil, err - } - for field := range fs { dict, err := reader.FieldDict(field) if err != nil { @@ -524,10 +516,12 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in Total: 1, Successful: 1, }, - Hits: knnHits, - SynonymResult: fts, - DocCount: count, - FieldCardinality: fieldCardinality, + Hits: knnHits, + SynonymResult: fts, + BM25Stats: &search.BM25Stats{ + DocCount: float64(count), + FieldCardinality: fieldCardinality, + }, }, nil } @@ -583,7 +577,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr var fts search.FieldTermSynonymMap var skipSynonymCollector bool - var bm25Data map[string]interface{} + var bm25Data *search.BM25Stats var ok bool if req.PreSearchData != nil { for k, v := range req.PreSearchData { @@ -607,7 +601,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr skipKNNCollector = true case search.BM25PreSearchDataKey: if v != nil { - bm25Data, ok = v.(map[string]interface{}) + bm25Data, ok = v.(*search.BM25Stats) if !ok { return nil, fmt.Errorf("bm25 preSearchData must be of type map[string]interface{}") } diff --git a/mapping/index.go b/mapping/index.go index a78a44a23..8a0d5e34a 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -557,6 +557,4 @@ func (im *IndexMappingImpl) SynonymSourceVisitor(visitor analysis.SynonymSourceV return err } return nil -func (im *IndexMappingImpl) BM25Impl() { - fmt.Println("BM25Impl") } diff --git a/mapping/mapping.go b/mapping/mapping.go index 9b7690c19..a6c1591b8 100644 --- a/mapping/mapping.go +++ b/mapping/mapping.go @@ -74,8 +74,3 @@ type SynonymMapping interface { SynonymSourceVisitor(visitor analysis.SynonymSourceVisitor) error } -type BM25Mapping interface { - IndexMapping - - BM25Impl() -} diff --git a/pre_search.go b/pre_search.go index 3872dcb61..0c3b8365e 100644 --- a/pre_search.go +++ b/pre_search.go @@ -14,9 +14,7 @@ package bleve -import ( - "github.com/blevesearch/bleve/v2/search" -) +import "github.com/blevesearch/bleve/v2/search" // A preSearchResultProcessor processes the data in // the preSearch result from multiple @@ -83,7 +81,7 @@ func (s *synonymPreSearchResultProcessor) finalize(sr *SearchResult) { } type bm25PreSearchResultProcessor struct { - docCount uint64 // bm25 specific stats + docCount float64 // bm25 specific stats fieldCardinality map[string]int } @@ -95,15 +93,19 @@ func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { // TODO How will this work for queries other than term queries? func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { - b.docCount += (sr.DocCount) - for field, cardinality := range sr.FieldCardinality { - b.fieldCardinality[field] += cardinality + if sr.BM25Stats != nil { + b.docCount += (sr.BM25Stats.DocCount) + for field, cardinality := range sr.BM25Stats.FieldCardinality { + b.fieldCardinality[field] += cardinality + } } } func (b *bm25PreSearchResultProcessor) finalize(sr *SearchResult) { - sr.DocCount = b.docCount - sr.FieldCardinality = b.fieldCardinality + sr.BM25Stats = &search.BM25Stats{ + DocCount: b.docCount, + FieldCardinality: b.fieldCardinality, + } } // ----------------------------------------------------------------------------- diff --git a/search.go b/search.go index 3e4e1a256..e13a93703 100644 --- a/search.go +++ b/search.go @@ -449,8 +449,7 @@ type SearchResult struct { SynonymResult search.FieldTermSynonymMap `json:"synonym_result,omitempty"` // The following fields are applicable to BM25 preSearch - DocCount uint64 - FieldCardinality map[string]int // search_field -> cardinality + BM25Stats *search.BM25Stats `json:"bm25_stats,omitempty"` } func (sr *SearchResult) Size() int { diff --git a/search/query/query.go b/search/query/query.go index 86859ae5b..a1f7b3404 100644 --- a/search/query/query.go +++ b/search/query/query.go @@ -105,6 +105,19 @@ func ParsePreSearchData(input []byte) (map[string]interface{}, error) { rv = make(map[string]interface{}) } rv[search.SynonymPreSearchDataKey] = value + case search.BM25PreSearchDataKey: + var value *search.BM25Stats + if v != nil { + err := util.UnmarshalJSON(v, &value) + if err != nil { + return nil, err + } + } + if rv == nil { + rv = make(map[string]interface{}) + } + rv[search.BM25PreSearchDataKey] = value + } } return rv, nil diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 456c10b8b..ef18bab47 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -87,7 +87,7 @@ func bm25ScoreMetrics(ctx context.Context, field string, var fieldCardinality int var err error - bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(map[string]interface{}) + bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(*search.BM25Stats) if !ok { count, err = indexReader.DocCount() if err != nil { @@ -99,13 +99,16 @@ func bm25ScoreMetrics(ctx context.Context, field string, } fieldCardinality = dict.Cardinality() } else { - fmt.Println("prefetching stats") - count = bm25Stats["docCount"].(uint64) - fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) - fieldCardinality, ok = fieldCardinalityMap[field] + count = uint64(bm25Stats.DocCount) + fieldCardinality, ok = bm25Stats.FieldCardinality[field] if !ok { return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) } + // fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) + // fieldCardinality, ok = fieldCardinalityMap[field] + // if !ok { + // return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) + // } } fmt.Println("----------bm25 stats--------") diff --git a/search/util.go b/search/util.go index d438cd61d..0c568e15b 100644 --- a/search/util.go +++ b/search/util.go @@ -171,6 +171,11 @@ const BM25MapKey = "_bm25_map_key" const GetSimilarityModelCallbackKey = "_get_similarity_model" +type BM25Stats struct { + DocCount float64 `json:"doc_count"` + FieldCardinality map[string]int `json:"field_cardinality"` +} + type GetSimilarityModelCallbackFn func(field string) string type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) From 52e318d9eeb95986781220d282739bc26b9e1643 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Thu, 2 Jan 2025 14:26:41 +0530 Subject: [PATCH 14/27] bug fixes after merge conflict resolution --- index_alias_impl.go | 58 +++++++++++++++++++--------------- index_impl.go | 1 + pre_search.go | 6 ++-- search/searcher/search_term.go | 2 +- 4 files changed, 38 insertions(+), 29 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 56abb34ff..88c021a90 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -16,6 +16,7 @@ package bleve import ( "context" + "fmt" "sync" "time" @@ -191,10 +192,11 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest // indicates that this index alias is set as an Index // in another alias, so we need to do a preSearch search // and NOT a real search + bm25PreSearch, _ := isBM25Enabled(req, i.mapping) flags := &preSearchFlags{ knn: requestHasKNN(req), synonyms: !isMatchNoneQuery(req.Query), - bm25: true, // TODO Just force setting it to true to test + bm25: bm25PreSearch, } return preSearchDataSearch(ctx, req, flags, i.indexes...) } @@ -234,7 +236,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest // - the request requires preSearch var preSearchDuration time.Duration var sr *SearchResult - flags, err := preSearchRequired(req, i.mapping) + flags, err := preSearchRequired(ctx, req, i.mapping) if err != nil { return nil, err } @@ -574,7 +576,7 @@ type asyncSearchResult struct { type preSearchFlags struct { knn bool synonyms bool - bm25 bool // needs presearch for this too + bm25 bool // needs presearch for this too } func isBM25Enabled(req *SearchRequest, m mapping.IndexMapping) (bool, query.FieldSet) { @@ -600,7 +602,7 @@ func isBM25Enabled(req *SearchRequest, m mapping.IndexMapping) (bool, query.Fiel // preSearchRequired checks if preSearch is required and returns the presearch flags struct // indicating which preSearch is required -func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error){ +func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error) { // Check for KNN query knn := requestHasKNN(req) var synonyms bool @@ -631,12 +633,12 @@ func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexM } } } - + if knn || synonyms || bm25 { return &preSearchFlags{ knn: knn, synonyms: synonyms, - bm25: bm25, + bm25: bm25, }, nil } return nil, nil @@ -646,7 +648,7 @@ func preSearch(ctx context.Context, req *SearchRequest, flags *preSearchFlags, i // create a dummy request with a match none query // since we only care about the preSearchData in PreSearch var dummyQuery = req.Query - if !flags.bm25 || !flags.synonyms { + if !flags.bm25 && !flags.synonyms { // create a dummy request with a match none query // since we only care about the preSearchData in PreSearch dummyQuery = query.NewMatchNoneQuery() @@ -728,7 +730,9 @@ func constructSynonymPreSearchData(rv map[string]map[string]interface{}, sr *Sea for _, index := range indexes { rv[index.Name()][search.SynonymPreSearchDataKey] = sr.SynonymResult } + return rv } + func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *SearchResult, indexes []Index) map[string]map[string]interface{} { bmStats := sr.BM25Stats if bmStats != nil { @@ -744,27 +748,27 @@ func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *Search func constructPreSearchData(req *SearchRequest, flags *preSearchFlags, preSearchResult *SearchResult, indexes []Index) (map[string]map[string]interface{}, error) { - if flags == nil || preSearchResult == nil { - return nil, fmt.Errorf("invalid input, flags: %v, preSearchResult: %v", flags, preSearchResult) - } - mergedOut := make(map[string]map[string]interface{}, len(indexes)) - for _, index := range indexes { - mergedOut[index.Name()] = make(map[string]interface{}) - } - var err error - if flags.knn { - mergedOut, err = constructKnnPreSearchData(mergedOut, preSearchResult, indexes) - if err != nil { - return nil, err - } - } - if flags.synonyms { - mergedOut = constructSynonymPreSearchData(mergedOut, preSearchResult, indexes) - if flags.bm25 { - mergedOut = constructBM25PreSearchData(mergedOut, preSearchResult, indexes) + if flags == nil || preSearchResult == nil { + return nil, fmt.Errorf("invalid input, flags: %v, preSearchResult: %v", flags, preSearchResult) + } + mergedOut := make(map[string]map[string]interface{}, len(indexes)) + for _, index := range indexes { + mergedOut[index.Name()] = make(map[string]interface{}) + } + var err error + if flags.knn { + mergedOut, err = constructKnnPreSearchData(mergedOut, preSearchResult, indexes) + if err != nil { + return nil, err } - return mergedOut, nil } + if flags.synonyms { + mergedOut = constructSynonymPreSearchData(mergedOut, preSearchResult, indexes) + } + if flags.bm25 { + mergedOut = constructBM25PreSearchData(mergedOut, preSearchResult, indexes) + } + return mergedOut, nil } func preSearchDataSearch(ctx context.Context, req *SearchRequest, flags *preSearchFlags, indexes ...Index) (*SearchResult, error) { @@ -871,6 +875,8 @@ func redistributePreSearchData(req *SearchRequest, indexes []Index) (map[string] if fts, ok := req.PreSearchData[search.SynonymPreSearchDataKey].(search.FieldTermSynonymMap); ok { for _, index := range indexes { rv[index.Name()][search.SynonymPreSearchDataKey] = fts + } + } if bm25Data, ok := req.PreSearchData[search.BM25PreSearchDataKey].(*search.BM25Stats); ok { for _, index := range indexes { diff --git a/index_impl.go b/index_impl.go index 5497e1df1..1261667d2 100644 --- a/index_impl.go +++ b/index_impl.go @@ -38,6 +38,7 @@ import ( "github.com/blevesearch/bleve/v2/search/collector" "github.com/blevesearch/bleve/v2/search/facet" "github.com/blevesearch/bleve/v2/search/highlight" + "github.com/blevesearch/bleve/v2/search/query" "github.com/blevesearch/bleve/v2/util" index "github.com/blevesearch/bleve_index_api" "github.com/blevesearch/geo/s2" diff --git a/pre_search.go b/pre_search.go index 0c3b8365e..3dd7e0fe3 100644 --- a/pre_search.go +++ b/pre_search.go @@ -14,7 +14,9 @@ package bleve -import "github.com/blevesearch/bleve/v2/search" +import ( + "github.com/blevesearch/bleve/v2/search" +) // A preSearchResultProcessor processes the data in // the preSearch result from multiple @@ -94,7 +96,7 @@ func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor { // TODO How will this work for queries other than term queries? func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) { if sr.BM25Stats != nil { - b.docCount += (sr.BM25Stats.DocCount) + b.docCount += sr.BM25Stats.DocCount for field, cardinality := range sr.BM25Stats.FieldCardinality { b.fieldCardinality[field] += cardinality } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index ef18bab47..55215b944 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -164,7 +164,7 @@ func NewSynonymSearcher(ctx context.Context, indexReader index.IndexReader, term if err != nil { return nil, err } - return newTermSearcherFromReader(indexReader, reader, term, field, boostVal, options) + return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boostVal, options) } // create a searcher for the term itself termSearcher, err := createTermSearcher(term, boost) From 36db386f13c05251ba502100be7694046cf32f1b Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Mon, 6 Jan 2025 12:28:27 +0530 Subject: [PATCH 15/27] score explanation --- index_alias_impl.go | 2 +- mapping/index.go | 1 + search/scorer/scorer_term.go | 61 +++++++++++++++++++++++++--------- search/searcher/search_term.go | 2 +- 4 files changed, 49 insertions(+), 17 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 88c021a90..7956a7a72 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -591,7 +591,7 @@ func isBM25Enabled(req *SearchRequest, m mapping.IndexMapping) (bool, query.Fiel // scoring. Otherwise, we just skip the presearch for field := range fs { f := m.FieldMappingForPath(field) - if f.Similarity == "" || f.Similarity == index.BM25Similarity { + if f.Similarity == index.BM25Similarity { rv = true break } diff --git a/mapping/index.go b/mapping/index.go index 8a0d5e34a..6399bef20 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -50,6 +50,7 @@ type IndexMappingImpl struct { DefaultAnalyzer string `json:"default_analyzer"` DefaultDateTimeParser string `json:"default_datetime_parser"` DefaultSynonymSource string `json:"default_synonym_source,omitempty"` + DefaultSimilarity string `json:"default_similarity,omitempty"` DefaultField string `json:"default_field"` StoreDynamic bool `json:"store_dynamic"` IndexDynamic bool `json:"index_dynamic"` diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index 5f966e489..7c4d6fab9 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -132,6 +132,11 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) { } } +// multiplies deciding how much does a doc length affect the score and also +// how much can the term frequency affect the score in BM25 scoring +var k1 float64 = 1.2 +var b float64 = 0.75 + func (s *TermQueryScorer) docScore(tf, norm float64) float64 { // tf-idf scoring by default score := tf * norm * s.idf @@ -140,16 +145,52 @@ func (s *TermQueryScorer) docScore(tf, norm float64) float64 { // using the posting's norm value to recompute the field length for the doc num fieldLength := 1 / (norm * norm) - // multiplies deciding how much does a doc length affect the score and also - // how much can the term frequency affect the score - var k1 float64 = 1.2 - var b float64 = 0.75 score = s.idf * (tf * k1) / (tf + k1*(1-b+(b*fieldLength/s.avgDocLength))) } return score } +func (s *TermQueryScorer) scoreExplanation(tf float64, termMatch *index.TermFieldDoc) []*search.Explanation { + var rv []*search.Explanation + if s.avgDocLength > 0 { + fieldLength := 1 / (termMatch.Norm * termMatch.Norm) + fieldNormVal := 1 - b + (b * fieldLength / s.avgDocLength) + fieldNormalizeExplanation := &search.Explanation{ + Value: fieldNormVal, + Message: fmt.Sprintf("fieldNorm(field=%s), b=%f, fieldLength=%f, avgFieldLength=%f)", + s.queryField, b, fieldLength, s.avgDocLength), + } + + saturationExplanation := &search.Explanation{ + Value: k1 / (tf + k1*fieldNormVal), + Message: fmt.Sprintf("saturation(term:%s), k1=%f/(tf=%f + k1*fieldNorm=%f))", + termMatch.Term, tf, k1, fieldNormVal), + Children: []*search.Explanation{fieldNormalizeExplanation}, + } + + rv = make([]*search.Explanation, 3) + rv[0] = &search.Explanation{ + Value: tf, + Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq), + } + rv[1] = saturationExplanation + rv[2] = s.idfExplanation + } else { + rv = make([]*search.Explanation, 3) + rv[0] = &search.Explanation{ + Value: tf, + Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq), + } + rv[1] = &search.Explanation{ + Value: termMatch.Norm, + Message: fmt.Sprintf("fieldNorm(field=%s, doc=%s)", s.queryField, termMatch.ID), + } + rv[2] = s.idfExplanation + } + return rv +} + func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.TermFieldDoc) *search.DocumentMatch { rv := ctx.DocumentMatchPool.Get() // perform any score computations only when needed @@ -163,18 +204,8 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term } score := s.docScore(tf, termMatch.Norm) - // todo: explain stuff properly if s.options.Explain { - childrenExplanations := make([]*search.Explanation, 3) - childrenExplanations[0] = &search.Explanation{ - Value: tf, - Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq), - } - childrenExplanations[1] = &search.Explanation{ - Value: termMatch.Norm, - Message: fmt.Sprintf("fieldNorm(field=%s, doc=%s)", s.queryField, termMatch.ID), - } - childrenExplanations[2] = s.idfExplanation + childrenExplanations := s.scoreExplanation(tf, termMatch) scoreExplanation = &search.Explanation{ Value: score, Message: fmt.Sprintf("fieldWeight(%s:%s in %s), product of:", s.queryField, s.queryTerm, termMatch.ID), diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 55215b944..c052ea00c 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -139,7 +139,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade if similaritModelCallback, ok := ctx.Value(search. GetSimilarityModelCallbackKey).(search.GetSimilarityModelCallbackFn); ok { similarityModel := similaritModelCallback(field) - if similarityModel == "" || similarityModel == index.BM25Similarity { + if similarityModel == index.BM25Similarity { // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) if err != nil { From e83cca00a193285ff0d8bb882b429bc5d4d97075 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Mon, 6 Jan 2025 12:46:13 +0530 Subject: [PATCH 16/27] default similarity config for an index --- index_test.go | 1 + mapping/index.go | 9 ++++++++- search/scorer/scorer_term.go | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/index_test.go b/index_test.go index 7fb3abf0c..df663c0cb 100644 --- a/index_test.go +++ b/index_test.go @@ -357,6 +357,7 @@ func TestBM25(t *testing.T) { indexMapping := NewIndexMapping() indexMapping.TypeField = "type" indexMapping.DefaultAnalyzer = "en" + indexMapping.DefaultSimilarity = index.BM25Similarity documentMapping := NewDocumentMapping() indexMapping.AddDocumentMapping("hotel", documentMapping) indexMapping.StoreDynamic = false diff --git a/mapping/index.go b/mapping/index.go index 6399bef20..70b333e86 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -488,7 +488,14 @@ func (im *IndexMappingImpl) FieldMappingForPath(path string) FieldMapping { return *fm } - return FieldMapping{} + // the edge case where there are no field mapping defined for a path, just + // return all the field specific defaults from the index mapping. + fm = &FieldMapping{ + Analyzer: im.DefaultAnalyzer, + Similarity: im.DefaultSimilarity, + SynonymSource: im.DefaultSynonymSource, + } + return *fm } // wrapper to satisfy new interface diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index 7c4d6fab9..5521e83ac 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -165,7 +165,7 @@ func (s *TermQueryScorer) scoreExplanation(tf float64, termMatch *index.TermFiel saturationExplanation := &search.Explanation{ Value: k1 / (tf + k1*fieldNormVal), Message: fmt.Sprintf("saturation(term:%s), k1=%f/(tf=%f + k1*fieldNorm=%f))", - termMatch.Term, tf, k1, fieldNormVal), + termMatch.Term, k1, tf, fieldNormVal), Children: []*search.Explanation{fieldNormalizeExplanation}, } From a643a3bd2ed265662c2c9f29b84e96ab5c8822e1 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Mon, 6 Jan 2025 16:59:07 +0530 Subject: [PATCH 17/27] cleanup --- index_test.go | 10 ++++++---- search/searcher/search_term.go | 10 ---------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/index_test.go b/index_test.go index df663c0cb..f6c0120ea 100644 --- a/index_test.go +++ b/index_test.go @@ -350,7 +350,7 @@ func TestBytesWritten(t *testing.T) { cleanupTmpIndexPath(t, tmpIndexPath4) } -func TestBM25(t *testing.T) { +func TestConsistentScoring(t *testing.T) { tmpIndexPath := createTmpIndexPath(t) defer cleanupTmpIndexPath(t, tmpIndexPath) @@ -402,9 +402,8 @@ func TestBM25(t *testing.T) { t.Error(err) } - fmt.Println("length of hits", res.Hits[0].Score) + singleIndexScore := res.Hits[0].Score dataset, _ := readDataFromFile("sample-data.json") - fmt.Println("length of dataset", len(dataset)) tmpIndexPath1 := createTmpIndexPath(t) defer cleanupTmpIndexPath(t, tmpIndexPath1) @@ -475,7 +474,10 @@ func TestBM25(t *testing.T) { t.Error(err) } - fmt.Println("length of hits alias search", res.Hits[0].Score) + if singleIndexScore != res.Hits[0].Score { + t.Fatalf("expected the scores to be the same, got %v and %v", + singleIndexScore, res.Hits[0].Score) + } } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index c052ea00c..8aef31d9a 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -104,18 +104,8 @@ func bm25ScoreMetrics(ctx context.Context, field string, if !ok { return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) } - // fieldCardinalityMap := bm25Stats["fieldCardinality"].(map[string]int) - // fieldCardinality, ok = fieldCardinalityMap[field] - // if !ok { - // return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) - // } } - fmt.Println("----------bm25 stats--------") - fmt.Println("docCount: ", count) - fmt.Println("fieldCardinality: ", fieldCardinality) - fmt.Println("avgDocLength: ", math.Ceil(float64(fieldCardinality)/float64(count))) - if count == 0 && fieldCardinality == 0 { return 0, 0, nil } From b5a7c9b9ff210c06e3cda91e0e304391bc39e427 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Tue, 7 Jan 2025 17:52:07 +0530 Subject: [PATCH 18/27] keeping scoring as an index level config for consistency --- index_alias_impl.go | 27 +++++++-------------------- index_impl.go | 20 ++++++++++++++------ index_test.go | 2 +- mapping/index.go | 22 +++++++++++++--------- search/searcher/search_term.go | 8 ++++---- search/util.go | 6 +++--- 6 files changed, 42 insertions(+), 43 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 7956a7a72..4c29d22e3 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -192,7 +192,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest // indicates that this index alias is set as an Index // in another alias, so we need to do a preSearch search // and NOT a real search - bm25PreSearch, _ := isBM25Enabled(req, i.mapping) + bm25PreSearch := isBM25Enabled(i.mapping) flags := &preSearchFlags{ knn: requestHasKNN(req), synonyms: !isMatchNoneQuery(req.Query), @@ -579,25 +579,12 @@ type preSearchFlags struct { bm25 bool // needs presearch for this too } -func isBM25Enabled(req *SearchRequest, m mapping.IndexMapping) (bool, query.FieldSet) { - rv := false - fs := make(query.FieldSet) - fs, err := query.ExtractFields(req.Query, m, fs) - if err != nil { - return rv, nil - } - // if there is any field that has bm25 scoring enabled, we set - // the flag to true to presearch the stats needed for the bm25 - // scoring. Otherwise, we just skip the presearch - for field := range fs { - f := m.FieldMappingForPath(field) - if f.Similarity == index.BM25Similarity { - rv = true - break - } +func isBM25Enabled(m mapping.IndexMapping) bool { + var rv bool + if m, ok := m.(*mapping.IndexMappingImpl); ok { + rv = m.ScoringModel == index.BM25Scoring } - - return rv, fs + return rv } // preSearchRequired checks if preSearch is required and returns the presearch flags struct @@ -628,7 +615,7 @@ func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexM if ctx != nil { if searchType := ctx.Value(search.SearchTypeKey); searchType != nil { if searchType.(string) == search.FetchStatsAndSearch { - bm25, _ = isBM25Enabled(req, m) + bm25 = isBM25Enabled(m) } } } diff --git a/index_impl.go b/index_impl.go index 1261667d2..7e40d8ebc 100644 --- a/index_impl.go +++ b/index_impl.go @@ -496,12 +496,17 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in } } } - if ok, fs := isBM25Enabled(req, i.m); ok { + if ok := isBM25Enabled(i.m); ok { count, err = reader.DocCount() if err != nil { return nil, err } + fs := make(query.FieldSet) + fs, err := query.ExtractFields(req.Query, i.m, fs) + if err != nil { + return nil, err + } for field := range fs { dict, err := reader.FieldDict(field) if err != nil { @@ -633,12 +638,15 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if fts != nil { ctx = context.WithValue(ctx, search.FieldTermSynonymMapKey, fts) } - fieldMappingCallback := func(field string) string { - rv := i.m.FieldMappingForPath(field) - return rv.Similarity + + scoringModelCallback := func() string { + if isBM25Enabled(i.m) { + return index.BM25Scoring + } + return index.DefaultScoringModel } - ctx = context.WithValue(ctx, search.GetSimilarityModelCallbackKey, - search.GetSimilarityModelCallbackFn(fieldMappingCallback)) + ctx = context.WithValue(ctx, search.GetScoringModelCallbackKey, + search.GetScoringModelCallbackFn(scoringModelCallback)) // set the bm25 presearch data (stats important for consistent scoring) in // the context object diff --git a/index_test.go b/index_test.go index f6c0120ea..8a11db593 100644 --- a/index_test.go +++ b/index_test.go @@ -357,7 +357,7 @@ func TestConsistentScoring(t *testing.T) { indexMapping := NewIndexMapping() indexMapping.TypeField = "type" indexMapping.DefaultAnalyzer = "en" - indexMapping.DefaultSimilarity = index.BM25Similarity + indexMapping.ScoringModel = index.BM25Scoring documentMapping := NewDocumentMapping() indexMapping.AddDocumentMapping("hotel", documentMapping) indexMapping.StoreDynamic = false diff --git a/mapping/index.go b/mapping/index.go index 70b333e86..8fb7c2b63 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -50,7 +50,7 @@ type IndexMappingImpl struct { DefaultAnalyzer string `json:"default_analyzer"` DefaultDateTimeParser string `json:"default_datetime_parser"` DefaultSynonymSource string `json:"default_synonym_source,omitempty"` - DefaultSimilarity string `json:"default_similarity,omitempty"` + ScoringModel string `json:"scoring_model,omitempty"` DefaultField string `json:"default_field"` StoreDynamic bool `json:"store_dynamic"` IndexDynamic bool `json:"index_dynamic"` @@ -202,6 +202,11 @@ func (im *IndexMappingImpl) Validate() error { return err } } + + if _, ok := index.SupportedScoringModels[im.ScoringModel]; !ok { + return fmt.Errorf("unsupported scoring model: %s", im.ScoringModel) + } + return nil } @@ -304,6 +309,12 @@ func (im *IndexMappingImpl) UnmarshalJSON(data []byte) error { if err != nil { return err } + case "scoring_model": + err := util.UnmarshalJSON(v, &im.ScoringModel) + if err != nil { + return err + } + default: invalidKeys = append(invalidKeys, k) } @@ -488,14 +499,7 @@ func (im *IndexMappingImpl) FieldMappingForPath(path string) FieldMapping { return *fm } - // the edge case where there are no field mapping defined for a path, just - // return all the field specific defaults from the index mapping. - fm = &FieldMapping{ - Analyzer: im.DefaultAnalyzer, - Similarity: im.DefaultSimilarity, - SynonymSource: im.DefaultSynonymSource, - } - return *fm + return FieldMapping{} } // wrapper to satisfy new interface diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 8aef31d9a..5d2a691b9 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -75,7 +75,7 @@ func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, float64, error) { } fieldCardinality := 0 - if count == 0 && fieldCardinality == 0 { + if count == 0 { return 0, 0, nil } return count, float64(fieldCardinality / int(count)), nil @@ -127,9 +127,9 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade } if ctx != nil { if similaritModelCallback, ok := ctx.Value(search. - GetSimilarityModelCallbackKey).(search.GetSimilarityModelCallbackFn); ok { - similarityModel := similaritModelCallback(field) - if similarityModel == index.BM25Similarity { + GetScoringModelCallbackKey).(search.GetScoringModelCallbackFn); ok { + similarityModel := similaritModelCallback() + if similarityModel == index.BM25Scoring { // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) if err != nil { diff --git a/search/util.go b/search/util.go index 0c568e15b..07538d730 100644 --- a/search/util.go +++ b/search/util.go @@ -169,13 +169,13 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { const FieldTermSynonymMapKey = "_field_term_synonym_map_key" const BM25MapKey = "_bm25_map_key" -const GetSimilarityModelCallbackKey = "_get_similarity_model" - type BM25Stats struct { DocCount float64 `json:"doc_count"` FieldCardinality map[string]int `json:"field_cardinality"` } -type GetSimilarityModelCallbackFn func(field string) string +const GetScoringModelCallbackKey = "_get_scoring_model" + +type GetScoringModelCallbackFn func() string type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) From 7c4873c4f8d6c2aafa62fc4a850e5bdc7a601f75 Mon Sep 17 00:00:00 2001 From: Abhinav Dangeti Date: Tue, 7 Jan 2025 08:24:22 -0700 Subject: [PATCH 19/27] Upgrade bleve_index_api, scorch_segment_api, zapx --- go.mod | 6 +++--- go.sum | 12 ++++++------ 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index cfee95607..409c9d21a 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,14 @@ go 1.21 require ( github.com/RoaringBitmap/roaring v1.9.3 github.com/bits-and-blooms/bitset v1.12.0 - github.com/blevesearch/bleve_index_api v1.2.0 + github.com/blevesearch/bleve_index_api v1.2.1 github.com/blevesearch/geo v0.1.20 github.com/blevesearch/go-faiss v1.0.24 github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475 github.com/blevesearch/go-porterstemmer v1.0.3 github.com/blevesearch/goleveldb v1.0.1 github.com/blevesearch/gtreap v0.1.1 - github.com/blevesearch/scorch_segment_api/v2 v2.3.0 + github.com/blevesearch/scorch_segment_api/v2 v2.3.1 github.com/blevesearch/segment v0.9.1 github.com/blevesearch/snowball v0.6.1 github.com/blevesearch/snowballstem v0.9.0 @@ -24,7 +24,7 @@ require ( github.com/blevesearch/zapx/v13 v13.3.10 github.com/blevesearch/zapx/v14 v14.3.10 github.com/blevesearch/zapx/v15 v15.3.17 - github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38 + github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612 github.com/couchbase/moss v0.2.0 github.com/golang/protobuf v1.3.2 github.com/spf13/cobra v1.7.0 diff --git a/go.sum b/go.sum index f21c89611..847eec37c 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/RoaringBitmap/roaring v1.9.3 h1:t4EbC5qQwnisr5PrP9nt0IRhRTb9gMUgQF4t4 github.com/RoaringBitmap/roaring v1.9.3/go.mod h1:6AXUsoIEzDTFFQCe1RbGA6uFONMhvejWj5rqITANK90= github.com/bits-and-blooms/bitset v1.12.0 h1:U/q1fAF7xXRhFCrhROzIfffYnu+dlS38vCZtmFVPHmA= github.com/bits-and-blooms/bitset v1.12.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= -github.com/blevesearch/bleve_index_api v1.2.0 h1:/DXMMWBwx/UmGKM1xDhTwDoJI5yQrG6rqRWPFcOgUVo= -github.com/blevesearch/bleve_index_api v1.2.0/go.mod h1:PbcwjIcRmjhGbkS/lJCpfgVSMROV6TRubGGAODaK1W8= +github.com/blevesearch/bleve_index_api v1.2.1 h1:IuXwLvmyp7I7+e0FOA68gcHHLfzSQ4AqQ8wVab5uxk0= +github.com/blevesearch/bleve_index_api v1.2.1/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0= github.com/blevesearch/geo v0.1.20 h1:paaSpu2Ewh/tn5DKn/FB5SzvH0EWupxHEIwbCk/QPqM= github.com/blevesearch/geo v0.1.20/go.mod h1:DVG2QjwHNMFmjo+ZgzrIq2sfCh6rIHzy9d9d0B59I6w= github.com/blevesearch/go-faiss v1.0.24 h1:K79IvKjoKHdi7FdiXEsAhxpMuns0x4fM0BO93bW5jLI= @@ -19,8 +19,8 @@ github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgY github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.3.0 h1:vxCjbXAkkEBSb4AB3Iqgr/EJcPyYRsiGxpcvsS8E1Dw= -github.com/blevesearch/scorch_segment_api/v2 v2.3.0/go.mod h1:5y+TgXYSx+xJGaCwSlvy9G/UJBIY5wzvIkhvhBm2ATc= +github.com/blevesearch/scorch_segment_api/v2 v2.3.1 h1:jjexIzwOdBtC9MlUceNErYHepLvoKxTdA5atbeZSRWE= +github.com/blevesearch/scorch_segment_api/v2 v2.3.1/go.mod h1:Np3Y03rsemM5TsyFxQ3wy+tG97EcviLTbp2S5W0tpRY= github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= github.com/blevesearch/snowball v0.6.1 h1:cDYjn/NCH+wwt2UdehaLpr2e4BwLIjN4V/TdLsL+B5A= @@ -43,8 +43,8 @@ github.com/blevesearch/zapx/v14 v14.3.10 h1:SG6xlsL+W6YjhX5N3aEiL/2tcWh3DO75Bnz7 github.com/blevesearch/zapx/v14 v14.3.10/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns= github.com/blevesearch/zapx/v15 v15.3.17 h1:NkkMI98pYLq/uHnB6YWcITrrLpCVyvZ9iP+AyfpW1Ys= github.com/blevesearch/zapx/v15 v15.3.17/go.mod h1:vXRQzJJvlGVCdmOD5hg7t7JdjUT5DmDPhsAfjvtzIq8= -github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38 h1:iJ3Q3sbyo2d0bjfb720RmGjj7cqzh/EdP3528ggDIMY= -github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38/go.mod h1:JTZseJiEpogtkepKSubIKAmfgbQiOReJXfmjxB1qta4= +github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612 h1:LhORiqEVyUPUrVETzmmVuT0Yudsz2R3qGLFJWUpMsQo= +github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612/go.mod h1:+FIylxb+5Z/sFVmNaGpppGLHKBMUEnPSbkKoi+izER8= github.com/couchbase/ghistogram v0.1.0 h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps= github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= github.com/couchbase/moss v0.2.0 h1:VCYrMzFwEryyhRSeI+/b3tRBSeTpi/8gn5Kf6dxqn+o= From 12c2c72a432ac26ddae0d0a0d3c1385f6dd7561e Mon Sep 17 00:00:00 2001 From: Abhinav Dangeti Date: Wed, 8 Jan 2025 10:03:06 -0700 Subject: [PATCH 20/27] Bump up zapx's v11, v12, v13, v14, v15 on account of interface change --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 409c9d21a..49b71da6e 100644 --- a/go.mod +++ b/go.mod @@ -19,11 +19,11 @@ require ( github.com/blevesearch/stempel v0.2.0 github.com/blevesearch/upsidedown_store_api v1.0.2 github.com/blevesearch/vellum v1.1.0 - github.com/blevesearch/zapx/v11 v11.3.10 - github.com/blevesearch/zapx/v12 v12.3.10 - github.com/blevesearch/zapx/v13 v13.3.10 - github.com/blevesearch/zapx/v14 v14.3.10 - github.com/blevesearch/zapx/v15 v15.3.17 + github.com/blevesearch/zapx/v11 v11.3.11 + github.com/blevesearch/zapx/v12 v12.3.11 + github.com/blevesearch/zapx/v13 v13.3.11 + github.com/blevesearch/zapx/v14 v14.3.11 + github.com/blevesearch/zapx/v15 v15.3.18 github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612 github.com/couchbase/moss v0.2.0 github.com/golang/protobuf v1.3.2 diff --git a/go.sum b/go.sum index 847eec37c..1914f7919 100644 --- a/go.sum +++ b/go.sum @@ -33,16 +33,16 @@ github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMG github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= github.com/blevesearch/vellum v1.1.0 h1:CinkGyIsgVlYf8Y2LUQHvdelgXr6PYuvoDIajq6yR9w= github.com/blevesearch/vellum v1.1.0/go.mod h1:QgwWryE8ThtNPxtgWJof5ndPfx0/YMBh+W2weHKPw8Y= -github.com/blevesearch/zapx/v11 v11.3.10 h1:hvjgj9tZ9DeIqBCxKhi70TtSZYMdcFn7gDb71Xo/fvk= -github.com/blevesearch/zapx/v11 v11.3.10/go.mod h1:0+gW+FaE48fNxoVtMY5ugtNHHof/PxCqh7CnhYdnMzQ= -github.com/blevesearch/zapx/v12 v12.3.10 h1:yHfj3vXLSYmmsBleJFROXuO08mS3L1qDCdDK81jDl8s= -github.com/blevesearch/zapx/v12 v12.3.10/go.mod h1:0yeZg6JhaGxITlsS5co73aqPtM04+ycnI6D1v0mhbCs= -github.com/blevesearch/zapx/v13 v13.3.10 h1:0KY9tuxg06rXxOZHg3DwPJBjniSlqEgVpxIqMGahDE8= -github.com/blevesearch/zapx/v13 v13.3.10/go.mod h1:w2wjSDQ/WBVeEIvP0fvMJZAzDwqwIEzVPnCPrz93yAk= -github.com/blevesearch/zapx/v14 v14.3.10 h1:SG6xlsL+W6YjhX5N3aEiL/2tcWh3DO75Bnz77pSwwKU= -github.com/blevesearch/zapx/v14 v14.3.10/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns= -github.com/blevesearch/zapx/v15 v15.3.17 h1:NkkMI98pYLq/uHnB6YWcITrrLpCVyvZ9iP+AyfpW1Ys= -github.com/blevesearch/zapx/v15 v15.3.17/go.mod h1:vXRQzJJvlGVCdmOD5hg7t7JdjUT5DmDPhsAfjvtzIq8= +github.com/blevesearch/zapx/v11 v11.3.11 h1:r6/wFHFAKWvXJb82f5aO53l6p+gRH6eiX7S1tb3VGc0= +github.com/blevesearch/zapx/v11 v11.3.11/go.mod h1:0+gW+FaE48fNxoVtMY5ugtNHHof/PxCqh7CnhYdnMzQ= +github.com/blevesearch/zapx/v12 v12.3.11 h1:GBBAmXesxXLV5UZ+FZ0qILb7HPssT+kxEkbPPfp5HPM= +github.com/blevesearch/zapx/v12 v12.3.11/go.mod h1:0yeZg6JhaGxITlsS5co73aqPtM04+ycnI6D1v0mhbCs= +github.com/blevesearch/zapx/v13 v13.3.11 h1:H5ZvgS1qM1XKzsAuwp3kvDfh5sJFu9bLH/B8U6Im5e8= +github.com/blevesearch/zapx/v13 v13.3.11/go.mod h1:w2wjSDQ/WBVeEIvP0fvMJZAzDwqwIEzVPnCPrz93yAk= +github.com/blevesearch/zapx/v14 v14.3.11 h1:pg+c/YFzMJ32GkOwLzH/HAQ/GBr6y1Ar7/K5ZQpxTNo= +github.com/blevesearch/zapx/v14 v14.3.11/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns= +github.com/blevesearch/zapx/v15 v15.3.18 h1:yJcQnQyHGNF6rAiwq85OHn3HaXo26t7vgd83RclEw7U= +github.com/blevesearch/zapx/v15 v15.3.18/go.mod h1:vXRQzJJvlGVCdmOD5hg7t7JdjUT5DmDPhsAfjvtzIq8= github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612 h1:LhORiqEVyUPUrVETzmmVuT0Yudsz2R3qGLFJWUpMsQo= github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612/go.mod h1:+FIylxb+5Z/sFVmNaGpppGLHKBMUEnPSbkKoi+izER8= github.com/couchbase/ghistogram v0.1.0 h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps= From ce537e67eafc41cf14f1104c21bd926922a79682 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Thu, 9 Jan 2025 12:23:08 +0530 Subject: [PATCH 21/27] code comments and handling edge case --- mapping/index.go | 2 +- search/searcher/search_term.go | 17 ++++++++------- search/util.go | 40 ++++++++++++++++++++++++---------- 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/mapping/index.go b/mapping/index.go index 8fb7c2b63..6150f2a38 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -203,7 +203,7 @@ func (im *IndexMappingImpl) Validate() error { } } - if _, ok := index.SupportedScoringModels[im.ScoringModel]; !ok { + if _, ok := index.SupportedScoringModels[im.ScoringModel]; !ok && im.ScoringModel != "" { return fmt.Errorf("unsupported scoring model: %s", im.ScoringModel) } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 5d2a691b9..748c2f095 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -73,8 +73,8 @@ func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, float64, error) { if err != nil { return 0, 0, err } + // field cardinality metric is not used in the tf-idf scoring algo. fieldCardinality := 0 - if count == 0 { return 0, 0, nil } @@ -120,17 +120,18 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade var err error // as a fallback case we track certain stats for tf-idf scoring - count, avgDocLength, err = tfTDFScoreMetrics(indexReader) - if err != nil { - _ = reader.Close() - return nil, err - } if ctx != nil { if similaritModelCallback, ok := ctx.Value(search. GetScoringModelCallbackKey).(search.GetScoringModelCallbackFn); ok { similarityModel := similaritModelCallback() - if similarityModel == index.BM25Scoring { - // in case of bm25 need to fetch the multipliers as well (perhaps via context's presearch data) + switch similarityModel { + case index.TFIDFScoring: + count, avgDocLength, err = tfTDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } + case index.BM25Scoring: count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) if err != nil { _ = reader.Close() diff --git a/search/util.go b/search/util.go index 07538d730..453adddbb 100644 --- a/search/util.go +++ b/search/util.go @@ -135,21 +135,46 @@ const MinGeoBufPoolSize = 24 type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool +// PreSearchKey indicates whether to perform a preliminary search to gather necessary +// information which would be used in the actual search down the line. +const PreSearchKey = "_presearch_key" + +// *PreSearchDataKey are used to store the data gathered during the presearch phase +// which would be use in the actual search phase. const KnnPreSearchDataKey = "_knn_pre_search_data_key" const SynonymPreSearchDataKey = "_synonym_pre_search_data_key" const BM25PreSearchDataKey = "_bm25_pre_search_data_key" -const PreSearchKey = "_presearch_key" - +// SearchTypeKey is used to identify type of the search being performed. +// +// for consistent scoring in cases an index is partitioned/sharded (using an +// index alias), FetchStatsAndSearch helps in aggregating the necessary stats across +// all the child bleve indexes (shards/partitions) first before the actual search +// is performed. const SearchTypeKey = "_search_type_key" -const FetchStatsAndSearch = "fetch_stats_and_search" +// The following keys are used to invoke the callbacks at the start and end stages +// of optimizing the disjunction/conjunction searcher creation. const SearcherStartCallbackKey = "_searcher_start_callback_key" const SearcherEndCallbackKey = "_searcher_end_callback_key" +// FieldTermSynonymMapKey is used to store and transport the synonym definitions data +// to the actual search phase which would use the synonyms to perform the search. +const FieldTermSynonymMapKey = "_field_term_synonym_map_key" + +const FetchStatsAndSearch = "fetch_stats_and_search" + +// GetScoringModelCallbackKey is used to help the underlying searcher identify +// which scoring mechanism to use based on index mapping. +const GetScoringModelCallbackKey = "_get_scoring_model" + type SearcherStartCallbackFn func(size uint64) error type SearcherEndCallbackFn func(size uint64) error +type GetScoringModelCallbackFn func() string + +type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) + // field -> term -> synonyms type FieldTermSynonymMap map[string]map[string][]string @@ -166,16 +191,7 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { } } -const FieldTermSynonymMapKey = "_field_term_synonym_map_key" -const BM25MapKey = "_bm25_map_key" - type BM25Stats struct { DocCount float64 `json:"doc_count"` FieldCardinality map[string]int `json:"field_cardinality"` } - -const GetScoringModelCallbackKey = "_get_scoring_model" - -type GetScoringModelCallbackFn func() string - -type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) From 79bd0c162db141ac13f382031ac1bc6c19f89f0b Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Thu, 9 Jan 2025 12:41:20 +0530 Subject: [PATCH 22/27] unit tests fix --- search/searcher/search_term.go | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index 748c2f095..f4d46b634 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -118,26 +118,29 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade var count uint64 var avgDocLength float64 var err error + var similarityModel string // as a fallback case we track certain stats for tf-idf scoring if ctx != nil { if similaritModelCallback, ok := ctx.Value(search. GetScoringModelCallbackKey).(search.GetScoringModelCallbackFn); ok { - similarityModel := similaritModelCallback() - switch similarityModel { - case index.TFIDFScoring: - count, avgDocLength, err = tfTDFScoreMetrics(indexReader) - if err != nil { - _ = reader.Close() - return nil, err - } - case index.BM25Scoring: - count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) - if err != nil { - _ = reader.Close() - return nil, err - } - } + similarityModel = similaritModelCallback() + } + } + switch similarityModel { + case index.BM25Scoring: + count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } + case index.TFIDFScoring: + fallthrough + default: + count, avgDocLength, err = tfTDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err } } scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), avgDocLength, options) From 8cdb5259a75532a8facb721ce78df2abeaa24e32 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Thu, 9 Jan 2025 15:37:41 +0530 Subject: [PATCH 23/27] cleanup? --- index_alias_impl.go | 2 +- index_test.go | 4 ++-- search/util.go | 7 ++++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/index_alias_impl.go b/index_alias_impl.go index 4c29d22e3..a4f724e34 100644 --- a/index_alias_impl.go +++ b/index_alias_impl.go @@ -614,7 +614,7 @@ func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexM if !isMatchNoneQuery(req.Query) { if ctx != nil { if searchType := ctx.Value(search.SearchTypeKey); searchType != nil { - if searchType.(string) == search.FetchStatsAndSearch { + if searchType.(string) == search.GlobalScoring { bm25 = isBM25Enabled(m) } } diff --git a/index_test.go b/index_test.go index 8a11db593..f2e267694 100644 --- a/index_test.go +++ b/index_test.go @@ -350,7 +350,7 @@ func TestBytesWritten(t *testing.T) { cleanupTmpIndexPath(t, tmpIndexPath4) } -func TestConsistentScoring(t *testing.T) { +func TestGlobalScoring(t *testing.T) { tmpIndexPath := createTmpIndexPath(t) defer cleanupTmpIndexPath(t, tmpIndexPath) @@ -467,7 +467,7 @@ func TestConsistentScoring(t *testing.T) { ctx := context.Background() // this key is set to ensure that we have a consistent scoring at the index alias // level (it forces a pre search phase which can have a small overhead) - ctx = context.WithValue(ctx, search.SearchTypeKey, search.FetchStatsAndSearch) + ctx = context.WithValue(ctx, search.SearchTypeKey, search.GlobalScoring) res, err = multiPartIndex.SearchInContext(ctx, searchRequest) if err != nil { diff --git a/search/util.go b/search/util.go index 453adddbb..129966e53 100644 --- a/search/util.go +++ b/search/util.go @@ -148,9 +148,10 @@ const BM25PreSearchDataKey = "_bm25_pre_search_data_key" // SearchTypeKey is used to identify type of the search being performed. // // for consistent scoring in cases an index is partitioned/sharded (using an -// index alias), FetchStatsAndSearch helps in aggregating the necessary stats across +// index alias), GlobalScoring helps in aggregating the necessary stats across // all the child bleve indexes (shards/partitions) first before the actual search -// is performed. +// is performed, such that the scoring involved using these stats would be at a +// global level. const SearchTypeKey = "_search_type_key" // The following keys are used to invoke the callbacks at the start and end stages @@ -162,7 +163,7 @@ const SearcherEndCallbackKey = "_searcher_end_callback_key" // to the actual search phase which would use the synonyms to perform the search. const FieldTermSynonymMapKey = "_field_term_synonym_map_key" -const FetchStatsAndSearch = "fetch_stats_and_search" +const GlobalScoring = "_global_scoring" // GetScoringModelCallbackKey is used to help the underlying searcher identify // which scoring mechanism to use based on index mapping. From d478f4f9278a8bf6966a374c713103a80e82a08b Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Fri, 10 Jan 2025 12:52:14 +0530 Subject: [PATCH 24/27] code comment, exposing the multipliers to be made configurable --- search/scorer/scorer_term.go | 21 ++++++++------------- search/searcher/search_term.go | 13 ++++++------- search/util.go | 7 +++++++ 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index 5521e83ac..15533686c 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -35,8 +35,8 @@ type TermQueryScorer struct { queryTerm string queryField string queryBoost float64 - docTerm uint64 - docTotal uint64 + docTerm uint64 // number of documents containing the term + docTotal uint64 // total number of documents in the index avgDocLength float64 idf float64 options search.SearcherOptions @@ -132,11 +132,6 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) { } } -// multiplies deciding how much does a doc length affect the score and also -// how much can the term frequency affect the score in BM25 scoring -var k1 float64 = 1.2 -var b float64 = 0.75 - func (s *TermQueryScorer) docScore(tf, norm float64) float64 { // tf-idf scoring by default score := tf * norm * s.idf @@ -145,8 +140,8 @@ func (s *TermQueryScorer) docScore(tf, norm float64) float64 { // using the posting's norm value to recompute the field length for the doc num fieldLength := 1 / (norm * norm) - score = s.idf * (tf * k1) / - (tf + k1*(1-b+(b*fieldLength/s.avgDocLength))) + score = s.idf * (tf * search.BM25_k1) / + (tf + search.BM25_k1*(1-search.BM25_b+(search.BM25_b*fieldLength/s.avgDocLength))) } return score } @@ -155,17 +150,17 @@ func (s *TermQueryScorer) scoreExplanation(tf float64, termMatch *index.TermFiel var rv []*search.Explanation if s.avgDocLength > 0 { fieldLength := 1 / (termMatch.Norm * termMatch.Norm) - fieldNormVal := 1 - b + (b * fieldLength / s.avgDocLength) + fieldNormVal := 1 - search.BM25_b + (search.BM25_b * fieldLength / s.avgDocLength) fieldNormalizeExplanation := &search.Explanation{ Value: fieldNormVal, Message: fmt.Sprintf("fieldNorm(field=%s), b=%f, fieldLength=%f, avgFieldLength=%f)", - s.queryField, b, fieldLength, s.avgDocLength), + s.queryField, search.BM25_b, fieldLength, s.avgDocLength), } saturationExplanation := &search.Explanation{ - Value: k1 / (tf + k1*fieldNormVal), + Value: search.BM25_k1 / (tf + search.BM25_k1*fieldNormVal), Message: fmt.Sprintf("saturation(term:%s), k1=%f/(tf=%f + k1*fieldNorm=%f))", - termMatch.Term, k1, tf, fieldNormVal), + termMatch.Term, search.BM25_k1, tf, fieldNormVal), Children: []*search.Explanation{fieldNormalizeExplanation}, } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index f4d46b634..1c33c6a41 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -67,18 +67,17 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options) } -func tfTDFScoreMetrics(indexReader index.IndexReader) (uint64, float64, error) { +func tfIDFScoreMetrics(indexReader index.IndexReader) (uint64, error) { // default tf-idf stats count, err := indexReader.DocCount() if err != nil { - return 0, 0, err + return 0, err } - // field cardinality metric is not used in the tf-idf scoring algo. - fieldCardinality := 0 + if count == 0 { - return 0, 0, nil + return 0, nil } - return count, float64(fieldCardinality / int(count)), nil + return count, nil } func bm25ScoreMetrics(ctx context.Context, field string, @@ -137,7 +136,7 @@ func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReade case index.TFIDFScoring: fallthrough default: - count, avgDocLength, err = tfTDFScoreMetrics(indexReader) + count, err = tfIDFScoreMetrics(indexReader) if err != nil { _ = reader.Close() return nil, err diff --git a/search/util.go b/search/util.go index 129966e53..cde561a64 100644 --- a/search/util.go +++ b/search/util.go @@ -192,6 +192,13 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { } } +// BM25 specific multipliers which affect the scoring of a document. +// +// BM25_b - how much does a doc's field length affect the score +// BM25_k1 - how much can the term frequency affect the score +var BM25_k1 float64 = 1.2 +var BM25_b float64 = 0.75 + type BM25Stats struct { DocCount float64 `json:"doc_count"` FieldCardinality map[string]int `json:"field_cardinality"` From eaca63a2e9593ca273acc3200b58591dd36d192d Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Mon, 13 Jan 2025 10:20:51 +0530 Subject: [PATCH 25/27] update score explanation, code cleanup --- search/scorer/scorer_term.go | 25 ++++++++++++++++++------- search/scorer/scorer_term_test.go | 8 ++++---- search/util.go | 8 +++++--- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go index 15533686c..f5f8ec935 100644 --- a/search/scorer/scorer_term.go +++ b/search/scorer/scorer_term.go @@ -76,6 +76,13 @@ func (s *TermQueryScorer) computeIDF(avgDocLength float64, docTotal, docTerm uin return rv } +// queryTerm - the specific term being scored by this scorer object +// queryField - the field in which the term is being searched +// queryBoost - the boost value for the query term +// docTotal - total number of documents in the index +// docTerm - number of documents containing the term +// avgDocLength - average document length in the index +// options - search options such as explain scoring, include the location of the term etc. func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer { @@ -132,9 +139,7 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) { } } -func (s *TermQueryScorer) docScore(tf, norm float64) float64 { - // tf-idf scoring by default - score := tf * norm * s.idf +func (s *TermQueryScorer) docScore(tf, norm float64) (score float64, model string) { if s.avgDocLength > 0 { // bm25 scoring // using the posting's norm value to recompute the field length for the doc num @@ -142,8 +147,13 @@ func (s *TermQueryScorer) docScore(tf, norm float64) float64 { score = s.idf * (tf * search.BM25_k1) / (tf + search.BM25_k1*(1-search.BM25_b+(search.BM25_b*fieldLength/s.avgDocLength))) + model = index.BM25Scoring + } else { + // tf-idf scoring by default + score = tf * norm * s.idf + model = index.DefaultScoringModel } - return score + return score, model } func (s *TermQueryScorer) scoreExplanation(tf float64, termMatch *index.TermFieldDoc) []*search.Explanation { @@ -198,12 +208,13 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term tf = math.Sqrt(float64(termMatch.Freq)) } - score := s.docScore(tf, termMatch.Norm) + score, scoringModel := s.docScore(tf, termMatch.Norm) if s.options.Explain { childrenExplanations := s.scoreExplanation(tf, termMatch) scoreExplanation = &search.Explanation{ - Value: score, - Message: fmt.Sprintf("fieldWeight(%s:%s in %s), product of:", s.queryField, s.queryTerm, termMatch.ID), + Value: score, + Message: fmt.Sprintf("fieldWeight(%s:%s in %s), as per %s model, "+ + "product of:", s.queryField, s.queryTerm, termMatch.ID, scoringModel), Children: childrenExplanations, } } diff --git a/search/scorer/scorer_term_test.go b/search/scorer/scorer_term_test.go index 5a7522514..097dbe243 100644 --- a/search/scorer/scorer_term_test.go +++ b/search/scorer/scorer_term_test.go @@ -58,7 +58,7 @@ func TestTermScorer(t *testing.T) { Sort: []string{}, Expl: &search.Explanation{ Value: math.Sqrt(1.0) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: 1, @@ -100,7 +100,7 @@ func TestTermScorer(t *testing.T) { Sort: []string{}, Expl: &search.Explanation{ Value: math.Sqrt(1.0) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: 1, @@ -131,7 +131,7 @@ func TestTermScorer(t *testing.T) { Sort: []string{}, Expl: &search.Explanation{ Value: math.Sqrt(65) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: math.Sqrt(65), @@ -224,7 +224,7 @@ func TestTermScorerWithQueryNorm(t *testing.T) { }, { Value: math.Sqrt(1.0) * idf, - Message: "fieldWeight(desc:beer in one), product of:", + Message: "fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: 1, diff --git a/search/util.go b/search/util.go index cde561a64..f4c631194 100644 --- a/search/util.go +++ b/search/util.go @@ -192,10 +192,12 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { } } -// BM25 specific multipliers which affect the scoring of a document. +// BM25 specific multipliers which control the scoring of a document. // -// BM25_b - how much does a doc's field length affect the score -// BM25_k1 - how much can the term frequency affect the score +// BM25_b - controls the extent to which doc's field length normalize term frequency part of score +// BM25_k1 - controls the saturation of the score due to term frequency +// the default values are as per elastic search's implementation +// - https://www.elastic.co/guide/en/elasticsearch/reference/current/index-modules-similarity.html#bm25 var BM25_k1 float64 = 1.2 var BM25_b float64 = 0.75 From fbd4ed89dc43e1c3554a41dd31bca49bbc6d8051 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Mon, 13 Jan 2025 10:23:52 +0530 Subject: [PATCH 26/27] update links --- search/util.go | 1 + 1 file changed, 1 insertion(+) diff --git a/search/util.go b/search/util.go index f4c631194..0530c6732 100644 --- a/search/util.go +++ b/search/util.go @@ -198,6 +198,7 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) { // BM25_k1 - controls the saturation of the score due to term frequency // the default values are as per elastic search's implementation // - https://www.elastic.co/guide/en/elasticsearch/reference/current/index-modules-similarity.html#bm25 +// - https://www.elastic.co/blog/practical-bm25-part-3-considerations-for-picking-b-and-k1-in-elasticsearch var BM25_k1 float64 = 1.2 var BM25_b float64 = 0.75 From 4d4b2d640e185eb6692a4e2826cec11ab0a99f11 Mon Sep 17 00:00:00 2001 From: Thejas-bhat Date: Tue, 14 Jan 2025 19:42:30 +0530 Subject: [PATCH 27/27] updating the unit tests and naming --- index_impl.go | 4 +-- index_test.go | 99 +++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/index_impl.go b/index_impl.go index 7e40d8ebc..d59dfb9a1 100644 --- a/index_impl.go +++ b/index_impl.go @@ -486,7 +486,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in var fts search.FieldTermSynonymMap var count uint64 - fieldCardinality := make(map[string]int) + var fieldCardinality map[string]int if !isMatchNoneQuery(req.Query) { if synMap, ok := i.m.(mapping.SynonymMapping); ok { if synReader, ok := reader.(index.ThesaurusReader); ok { @@ -497,6 +497,7 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in } } if ok := isBM25Enabled(i.m); ok { + fieldCardinality = make(map[string]int) count, err = reader.DocCount() if err != nil { return nil, err @@ -604,7 +605,6 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr } skipSynonymCollector = true } - skipKNNCollector = true case search.BM25PreSearchDataKey: if v != nil { bm25Data, ok = v.(*search.BM25Stats) diff --git a/index_test.go b/index_test.go index f2e267694..c2844584a 100644 --- a/index_test.go +++ b/index_test.go @@ -350,14 +350,11 @@ func TestBytesWritten(t *testing.T) { cleanupTmpIndexPath(t, tmpIndexPath4) } -func TestGlobalScoring(t *testing.T) { - tmpIndexPath := createTmpIndexPath(t) - defer cleanupTmpIndexPath(t, tmpIndexPath) - +func createIndexMappingOnSampleData() *mapping.IndexMappingImpl { indexMapping := NewIndexMapping() indexMapping.TypeField = "type" indexMapping.DefaultAnalyzer = "en" - indexMapping.ScoringModel = index.BM25Scoring + indexMapping.ScoringModel = index.DefaultScoringModel documentMapping := NewDocumentMapping() indexMapping.AddDocumentMapping("hotel", documentMapping) indexMapping.StoreDynamic = false @@ -373,6 +370,85 @@ func TestGlobalScoring(t *testing.T) { typeFieldMapping.Store = false documentMapping.AddFieldMappingsAt("type", typeFieldMapping) + return indexMapping +} + +func TestBM25TFIDFScoring(t *testing.T) { + tmpIndexPath1 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath1) + tmpIndexPath2 := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath2) + + indexMapping := createIndexMappingOnSampleData() + indexMapping.ScoringModel = index.BM25Scoring + indexBM25, err := NewUsing(tmpIndexPath1, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + indexMapping1 := createIndexMappingOnSampleData() + indexTFIDF, err := NewUsing(tmpIndexPath2, indexMapping1, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) + if err != nil { + t.Fatal(err) + } + + defer func() { + err := indexBM25.Close() + if err != nil { + t.Fatal(err) + } + + err = indexTFIDF.Close() + if err != nil { + t.Fatal(err) + } + }() + + batch, err := getBatchFromData(indexBM25, "sample-data.json") + if err != nil { + t.Fatalf("failed to form a batch") + } + err = indexBM25.Batch(batch) + if err != nil { + t.Fatalf("failed to index batch %v\n", err) + } + query := NewMatchQuery("Hotel") + query.FieldVal = "name" + searchRequest := NewSearchRequestOptions(query, int(10), 0, true) + + resBM25, err := indexBM25.Search(searchRequest) + if err != nil { + t.Error(err) + } + + batch, err = getBatchFromData(indexTFIDF, "sample-data.json") + if err != nil { + t.Fatalf("failed to form a batch") + } + err = indexTFIDF.Batch(batch) + if err != nil { + t.Fatalf("failed to index batch %v\n", err) + } + + resTFIDF, err := indexTFIDF.Search(searchRequest) + if err != nil { + t.Error(err) + } + + for i, hit := range resTFIDF.Hits { + if hit.Score < resBM25.Hits[i].Score { + t.Fatalf("expected the score to be higher for BM25, got %v and %v", + resBM25.Hits[i].Score, hit.Score) + } + } +} + +func TestBM25GlobalScoring(t *testing.T) { + tmpIndexPath := createTmpIndexPath(t) + defer cleanupTmpIndexPath(t, tmpIndexPath) + + indexMapping := createIndexMappingOnSampleData() + indexMapping.ScoringModel = index.BM25Scoring idxSinglePartition, err := NewUsing(tmpIndexPath, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil) if err != nil { t.Fatal(err) @@ -393,7 +469,7 @@ func TestGlobalScoring(t *testing.T) { if err != nil { t.Fatalf("failed to index batch %v\n", err) } - query := NewMatchQuery("Apartments") + query := NewMatchQuery("Hotel") query.FieldVal = "name" searchRequest := NewSearchRequestOptions(query, int(10), 0, true) @@ -402,7 +478,8 @@ func TestGlobalScoring(t *testing.T) { t.Error(err) } - singleIndexScore := res.Hits[0].Score + singlePartHits := res.Hits + dataset, _ := readDataFromFile("sample-data.json") tmpIndexPath1 := createTmpIndexPath(t) defer cleanupTmpIndexPath(t, tmpIndexPath1) @@ -474,9 +551,11 @@ func TestGlobalScoring(t *testing.T) { t.Error(err) } - if singleIndexScore != res.Hits[0].Score { - t.Fatalf("expected the scores to be the same, got %v and %v", - singleIndexScore, res.Hits[0].Score) + for i, hit := range res.Hits { + if hit.Score != singlePartHits[i].Score { + t.Fatalf("expected the scores to be the same, got %v and %v", + hit.Score, singlePartHits[i].Score) + } } }