diff --git a/go.mod b/go.mod
index cfee95607..49b71da6e 100644
--- a/go.mod
+++ b/go.mod
@@ -5,26 +5,26 @@ go 1.21
 
 require (
 	github.com/RoaringBitmap/roaring v1.9.3
 	github.com/bits-and-blooms/bitset v1.12.0
-	github.com/blevesearch/bleve_index_api v1.2.0
+	github.com/blevesearch/bleve_index_api v1.2.1
 	github.com/blevesearch/geo v0.1.20
 	github.com/blevesearch/go-faiss v1.0.24
 	github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475
 	github.com/blevesearch/go-porterstemmer v1.0.3
 	github.com/blevesearch/goleveldb v1.0.1
 	github.com/blevesearch/gtreap v0.1.1
-	github.com/blevesearch/scorch_segment_api/v2 v2.3.0
+	github.com/blevesearch/scorch_segment_api/v2 v2.3.1
 	github.com/blevesearch/segment v0.9.1
 	github.com/blevesearch/snowball v0.6.1
 	github.com/blevesearch/snowballstem v0.9.0
 	github.com/blevesearch/stempel v0.2.0
 	github.com/blevesearch/upsidedown_store_api v1.0.2
 	github.com/blevesearch/vellum v1.1.0
-	github.com/blevesearch/zapx/v11 v11.3.10
-	github.com/blevesearch/zapx/v12 v12.3.10
-	github.com/blevesearch/zapx/v13 v13.3.10
-	github.com/blevesearch/zapx/v14 v14.3.10
-	github.com/blevesearch/zapx/v15 v15.3.17
-	github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38
+	github.com/blevesearch/zapx/v11 v11.3.11
+	github.com/blevesearch/zapx/v12 v12.3.11
+	github.com/blevesearch/zapx/v13 v13.3.11
+	github.com/blevesearch/zapx/v14 v14.3.11
+	github.com/blevesearch/zapx/v15 v15.3.18
+	github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612
 	github.com/couchbase/moss v0.2.0
 	github.com/golang/protobuf v1.3.2
 	github.com/spf13/cobra v1.7.0
diff --git a/go.sum b/go.sum
index f21c89611..1914f7919 100644
--- a/go.sum
+++ b/go.sum
@@ -2,8 +2,8 @@ github.com/RoaringBitmap/roaring v1.9.3 h1:t4EbC5qQwnisr5PrP9nt0IRhRTb9gMUgQF4t4
 github.com/RoaringBitmap/roaring v1.9.3/go.mod h1:6AXUsoIEzDTFFQCe1RbGA6uFONMhvejWj5rqITANK90=
 github.com/bits-and-blooms/bitset v1.12.0 h1:U/q1fAF7xXRhFCrhROzIfffYnu+dlS38vCZtmFVPHmA=
 github.com/bits-and-blooms/bitset v1.12.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8=
-github.com/blevesearch/bleve_index_api v1.2.0 h1:/DXMMWBwx/UmGKM1xDhTwDoJI5yQrG6rqRWPFcOgUVo=
-github.com/blevesearch/bleve_index_api v1.2.0/go.mod h1:PbcwjIcRmjhGbkS/lJCpfgVSMROV6TRubGGAODaK1W8=
+github.com/blevesearch/bleve_index_api v1.2.1 h1:IuXwLvmyp7I7+e0FOA68gcHHLfzSQ4AqQ8wVab5uxk0=
+github.com/blevesearch/bleve_index_api v1.2.1/go.mod h1:rKQDl4u51uwafZxFrPD1R7xFOwKnzZW7s/LSeK4lgo0=
 github.com/blevesearch/geo v0.1.20 h1:paaSpu2Ewh/tn5DKn/FB5SzvH0EWupxHEIwbCk/QPqM=
 github.com/blevesearch/geo v0.1.20/go.mod h1:DVG2QjwHNMFmjo+ZgzrIq2sfCh6rIHzy9d9d0B59I6w=
 github.com/blevesearch/go-faiss v1.0.24 h1:K79IvKjoKHdi7FdiXEsAhxpMuns0x4fM0BO93bW5jLI=
@@ -19,8 +19,8 @@ github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgY
 github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA=
 github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc=
 github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs=
-github.com/blevesearch/scorch_segment_api/v2 v2.3.0 h1:vxCjbXAkkEBSb4AB3Iqgr/EJcPyYRsiGxpcvsS8E1Dw=
-github.com/blevesearch/scorch_segment_api/v2 v2.3.0/go.mod h1:5y+TgXYSx+xJGaCwSlvy9G/UJBIY5wzvIkhvhBm2ATc=
+github.com/blevesearch/scorch_segment_api/v2 v2.3.1 h1:jjexIzwOdBtC9MlUceNErYHepLvoKxTdA5atbeZSRWE=
+github.com/blevesearch/scorch_segment_api/v2 v2.3.1/go.mod h1:Np3Y03rsemM5TsyFxQ3wy+tG97EcviLTbp2S5W0tpRY=
 github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU=
 github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw=
 github.com/blevesearch/snowball v0.6.1 h1:cDYjn/NCH+wwt2UdehaLpr2e4BwLIjN4V/TdLsL+B5A=
@@ -33,18 +33,18 @@ github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMG
 github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ=
 github.com/blevesearch/vellum v1.1.0 h1:CinkGyIsgVlYf8Y2LUQHvdelgXr6PYuvoDIajq6yR9w=
 github.com/blevesearch/vellum v1.1.0/go.mod h1:QgwWryE8ThtNPxtgWJof5ndPfx0/YMBh+W2weHKPw8Y=
-github.com/blevesearch/zapx/v11 v11.3.10 h1:hvjgj9tZ9DeIqBCxKhi70TtSZYMdcFn7gDb71Xo/fvk=
-github.com/blevesearch/zapx/v11 v11.3.10/go.mod h1:0+gW+FaE48fNxoVtMY5ugtNHHof/PxCqh7CnhYdnMzQ=
-github.com/blevesearch/zapx/v12 v12.3.10 h1:yHfj3vXLSYmmsBleJFROXuO08mS3L1qDCdDK81jDl8s=
-github.com/blevesearch/zapx/v12 v12.3.10/go.mod h1:0yeZg6JhaGxITlsS5co73aqPtM04+ycnI6D1v0mhbCs=
-github.com/blevesearch/zapx/v13 v13.3.10 h1:0KY9tuxg06rXxOZHg3DwPJBjniSlqEgVpxIqMGahDE8=
-github.com/blevesearch/zapx/v13 v13.3.10/go.mod h1:w2wjSDQ/WBVeEIvP0fvMJZAzDwqwIEzVPnCPrz93yAk=
-github.com/blevesearch/zapx/v14 v14.3.10 h1:SG6xlsL+W6YjhX5N3aEiL/2tcWh3DO75Bnz77pSwwKU=
-github.com/blevesearch/zapx/v14 v14.3.10/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns=
-github.com/blevesearch/zapx/v15 v15.3.17 h1:NkkMI98pYLq/uHnB6YWcITrrLpCVyvZ9iP+AyfpW1Ys=
-github.com/blevesearch/zapx/v15 v15.3.17/go.mod h1:vXRQzJJvlGVCdmOD5hg7t7JdjUT5DmDPhsAfjvtzIq8=
-github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38 h1:iJ3Q3sbyo2d0bjfb720RmGjj7cqzh/EdP3528ggDIMY=
-github.com/blevesearch/zapx/v16 v16.1.11-0.20241219160422-82553cdd4b38/go.mod h1:JTZseJiEpogtkepKSubIKAmfgbQiOReJXfmjxB1qta4=
+github.com/blevesearch/zapx/v11 v11.3.11 h1:r6/wFHFAKWvXJb82f5aO53l6p+gRH6eiX7S1tb3VGc0=
+github.com/blevesearch/zapx/v11 v11.3.11/go.mod h1:0+gW+FaE48fNxoVtMY5ugtNHHof/PxCqh7CnhYdnMzQ=
+github.com/blevesearch/zapx/v12 v12.3.11 h1:GBBAmXesxXLV5UZ+FZ0qILb7HPssT+kxEkbPPfp5HPM=
+github.com/blevesearch/zapx/v12 v12.3.11/go.mod h1:0yeZg6JhaGxITlsS5co73aqPtM04+ycnI6D1v0mhbCs=
+github.com/blevesearch/zapx/v13 v13.3.11 h1:H5ZvgS1qM1XKzsAuwp3kvDfh5sJFu9bLH/B8U6Im5e8=
+github.com/blevesearch/zapx/v13 v13.3.11/go.mod h1:w2wjSDQ/WBVeEIvP0fvMJZAzDwqwIEzVPnCPrz93yAk=
+github.com/blevesearch/zapx/v14 v14.3.11 h1:pg+c/YFzMJ32GkOwLzH/HAQ/GBr6y1Ar7/K5ZQpxTNo=
+github.com/blevesearch/zapx/v14 v14.3.11/go.mod h1:qqyuR0u230jN1yMmE4FIAuCxmahRQEOehF78m6oTgns=
+github.com/blevesearch/zapx/v15 v15.3.18 h1:yJcQnQyHGNF6rAiwq85OHn3HaXo26t7vgd83RclEw7U=
+github.com/blevesearch/zapx/v15 v15.3.18/go.mod h1:vXRQzJJvlGVCdmOD5hg7t7JdjUT5DmDPhsAfjvtzIq8=
+github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612 h1:LhORiqEVyUPUrVETzmmVuT0Yudsz2R3qGLFJWUpMsQo=
+github.com/blevesearch/zapx/v16 v16.1.11-0.20250107152255-021e66397612/go.mod h1:+FIylxb+5Z/sFVmNaGpppGLHKBMUEnPSbkKoi+izER8=
 github.com/couchbase/ghistogram v0.1.0 h1:b95QcQTCzjTUocDXp/uMgSNQi8oj1tGwnJ4bODWZnps=
 github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k=
 github.com/couchbase/moss v0.2.0 h1:VCYrMzFwEryyhRSeI+/b3tRBSeTpi/8gn5Kf6dxqn+o=
diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go
index 6d0a0b60e..ece32eee6 100644
--- a/index/scorch/snapshot_index.go
+++ b/index/scorch/snapshot_index.go
@@ -42,8 +42,9 @@ type asynchSegmentResult struct {
 	dict    segment.TermDictionary
 	dictItr segment.DictionaryIterator
 
-	index int
-	docs  *roaring.Bitmap
+	cardinality int
+	index       int
+	docs        *roaring.Bitmap
 
 	thesItr segment.ThesaurusIterator
 
@@ -137,6 +138,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string,
 
 	results := make(chan *asynchSegmentResult)
 	var totalBytesRead uint64
+	var fieldCardinality int64
 	for _, s := range is.segment {
 		go func(s *SegmentSnapshot) {
 			dict, err := s.segment.Dictionary(field)
@@ -146,6 +148,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string,
 			if dictStats, ok := dict.(segment.DiskStatsReporter); ok {
 				atomic.AddUint64(&totalBytesRead, dictStats.BytesRead())
 			}
+			atomic.AddInt64(&fieldCardinality, int64(dict.Cardinality()))
 			if randomLookup {
 				results <- &asynchSegmentResult{dict: dict}
 			} else {
@@ -160,6 +163,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string,
 		snapshot: is,
 		cursors:  make([]*segmentDictCursor, 0, len(is.segment)),
 	}
+
 	for count := 0; count < len(is.segment); count++ {
 		asr := <-results
 		if asr.err != nil && err == nil {
@@ -183,6 +187,7 @@ func (is *IndexSnapshot) newIndexSnapshotFieldDict(field string,
 			}
 		}
 	}
+	rv.cardinality = int(fieldCardinality)
 	rv.bytesRead = totalBytesRead
 	// after ensuring we've read all items on channel
 	if err != nil {
diff --git a/index/scorch/snapshot_index_dict.go b/index/scorch/snapshot_index_dict.go
index 658aa8148..2ae789c6b 100644
--- a/index/scorch/snapshot_index_dict.go
+++ b/index/scorch/snapshot_index_dict.go
@@ -28,10 +28,12 @@ type segmentDictCursor struct {
 }
 
 type IndexSnapshotFieldDict struct {
-	snapshot  *IndexSnapshot
-	cursors   []*segmentDictCursor
-	entry     index.DictEntry
-	bytesRead uint64
+	cardinality int
+	bytesRead   uint64
+
+	snapshot *IndexSnapshot
+	cursors  []*segmentDictCursor
+	entry    index.DictEntry
 }
 
 func (i *IndexSnapshotFieldDict) BytesRead() uint64 {
@@ -94,6 +96,10 @@ func (i *IndexSnapshotFieldDict) Next() (*index.DictEntry, error) {
 	return &i.entry, nil
 }
 
+func (i *IndexSnapshotFieldDict) Cardinality() int {
+	return i.cardinality
+}
+
 func (i *IndexSnapshotFieldDict) Close() error {
 	return nil
}
diff --git a/index/upsidedown/field_dict.go b/index/upsidedown/field_dict.go
index 4875680c9..c990fd47b 100644
--- a/index/upsidedown/field_dict.go
+++ b/index/upsidedown/field_dict.go
@@ -77,6 +77,11 @@ func (r *UpsideDownCouchFieldDict) Next() (*index.DictEntry, error) {
 
 }
 
+// Cardinality is not tracked by the upsidedown index, so this always returns 0
+func (r *UpsideDownCouchFieldDict) Cardinality() int {
+	return 0
+}
+
 func (r *UpsideDownCouchFieldDict) Close() error {
 	return r.iterator.Close()
 }
diff --git a/index_alias_impl.go b/index_alias_impl.go
index 766240b4a..a4f724e34 100644
--- a/index_alias_impl.go
+++ b/index_alias_impl.go
@@ -192,9 +192,11 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
 		// indicates that this index alias is set as an Index
 		// in another alias, so we need to do a preSearch search
 		// and NOT a real search
+		bm25PreSearch := isBM25Enabled(i.mapping)
 		flags := &preSearchFlags{
 			knn:      requestHasKNN(req),
 			synonyms: !isMatchNoneQuery(req.Query),
+			bm25:     bm25PreSearch,
 		}
 		return preSearchDataSearch(ctx, req, flags, i.indexes...)
 	}
@@ -234,7 +236,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
 	// - the request requires preSearch
 	var preSearchDuration time.Duration
 	var sr *SearchResult
-	flags, err := preSearchRequired(req, i.mapping)
+	flags, err := preSearchRequired(ctx, req, i.mapping)
 	if err != nil {
 		return nil, err
 	}
@@ -244,6 +246,7 @@ func (i *indexAliasImpl) SearchInContext(ctx context.Context, req *SearchRequest
 		if err != nil {
 			return nil, err
 		}
+
 		// check if the preSearch result has any errors and if so
 		// return the search result as is without executing the query
 		// so that the errors are not lost
@@ -573,11 +576,20 @@ type asyncSearchResult struct {
 type preSearchFlags struct {
 	knn      bool
 	synonyms bool
+	bm25     bool // BM25 scoring also needs a preSearch phase, to gather global stats
 }
 
-// preSearchRequired checks if preSearch is required and returns a boolean flag
-// It only allocates the preSearchFlags struct if necessary
-func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error) {
+func isBM25Enabled(m mapping.IndexMapping) bool {
+	var rv bool
+	if m, ok := m.(*mapping.IndexMappingImpl); ok {
+		rv = m.ScoringModel == index.BM25Scoring
+	}
+	return rv
+}
+
+// preSearchRequired checks if a preSearch phase is required and returns the
+// preSearch flags struct indicating which preSearch operations are required
+func preSearchRequired(ctx context.Context, req *SearchRequest, m mapping.IndexMapping) (*preSearchFlags, error) {
 	// Check for KNN query
 	knn := requestHasKNN(req)
 	var synonyms bool
@@ -598,18 +610,27 @@ func preSearchRequired(req *SearchRequest, m mapping.IndexMapping) (*preSearchFl
 			}
 		}
 	}
-	if knn || synonyms {
+	var bm25 bool
+	if !isMatchNoneQuery(req.Query) && ctx != nil {
+		if searchType, ok := ctx.Value(search.SearchTypeKey).(string); ok &&
+			searchType == search.GlobalScoring {
+			bm25 = isBM25Enabled(m)
+		}
+	}
+
+	if knn || synonyms || bm25 {
 		return &preSearchFlags{
 			knn:      knn,
 			synonyms: synonyms,
+			bm25:     bm25,
 		}, nil
 	}
 	return nil, nil
 }
 
 func preSearch(ctx context.Context, req *SearchRequest, flags *preSearchFlags, indexes ...Index) (*SearchResult, error) {
 	var dummyQuery = req.Query
-	if !flags.synonyms {
+	if !flags.bm25 && !flags.synonyms {
 		// create a dummy request with a match none query
 		// since we only care about the preSearchData in PreSearch
 		dummyQuery = query.NewMatchNoneQuery()
@@ -694,6 +720,19 @@ func constructSynonymPreSearchData(rv map[string]map[string]interface{}, sr *Sea
 	return rv
 }
 
+func constructBM25PreSearchData(rv map[string]map[string]interface{}, sr *SearchResult, indexes []Index) map[string]map[string]interface{} {
+	bmStats := sr.BM25Stats
+	if bmStats != nil {
+		for _, index := range indexes {
+			rv[index.Name()][search.BM25PreSearchDataKey] = &search.BM25Stats{
+				DocCount:         bmStats.DocCount,
+				FieldCardinality: bmStats.FieldCardinality,
+			}
+		}
+	}
+	return rv
+}
+
 func constructPreSearchData(req *SearchRequest, flags *preSearchFlags,
 	preSearchResult *SearchResult, indexes []Index) (map[string]map[string]interface{}, error) {
 	if flags == nil || preSearchResult == nil {
@@ -713,6 +752,9 @@ func constructPreSearchData(req *SearchRequest, flags *preSearchFlags,
 	if flags.synonyms {
 		mergedOut = constructSynonymPreSearchData(mergedOut, preSearchResult, indexes)
 	}
+	if flags.bm25 {
+		mergedOut = constructBM25PreSearchData(mergedOut, preSearchResult, indexes)
+	}
 	return mergedOut, nil
 }
 
@@ -822,6 +864,12 @@ func redistributePreSearchData(req *SearchRequest, indexes []Index) (map[string]
 			rv[index.Name()][search.SynonymPreSearchDataKey] = fts
 		}
 	}
+
+	if bm25Data, ok := req.PreSearchData[search.BM25PreSearchDataKey].(*search.BM25Stats); ok {
+		for _, index := range indexes {
+			rv[index.Name()][search.BM25PreSearchDataKey] = bm25Data
+		}
+	}
 	return rv, nil
 }
 
@@ -1009,3 +1057,7 @@ func (f *indexAliasImplFieldDict) Close() error {
 	defer f.index.mutex.RUnlock()
 	return f.fieldDict.Close()
 }
+
+func (f *indexAliasImplFieldDict) Cardinality() int {
+	return f.fieldDict.Cardinality()
+}
diff --git a/index_impl.go b/index_impl.go
index 289014f6c..d59dfb9a1 100644
--- a/index_impl.go
+++ b/index_impl.go
@@ -485,6 +485,8 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in
 	}
 
 	var fts search.FieldTermSynonymMap
+	var count uint64
+	var fieldCardinality map[string]int
 	if !isMatchNoneQuery(req.Query) {
 		if synMap, ok := i.m.(mapping.SynonymMapping); ok {
 			if synReader, ok := reader.(index.ThesaurusReader); ok {
@@ -494,6 +496,29 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in
 			}
 		}
 		if ok := isBM25Enabled(i.m); ok {
+			fieldCardinality = make(map[string]int)
+			count, err = reader.DocCount()
+			if err != nil {
+				return nil, err
+			}
+
+			fs := make(query.FieldSet)
+			fs, err := query.ExtractFields(req.Query, i.m, fs)
+			if err != nil {
+				return nil, err
+			}
+			for field := range fs {
+				dict, err := reader.FieldDict(field)
+				if err != nil {
+					return nil, err
+				}
+				fieldCardinality[field] = dict.Cardinality()
+				if err := dict.Close(); err != nil {
+					return nil, err
+				}
+			}
+		}
 	}
 
 	return &SearchResult{
@@ -503,6 +525,10 @@ func (i *indexImpl) preSearch(ctx context.Context, req *SearchRequest, reader in
 		},
 		Hits:          knnHits,
 		SynonymResult: fts,
+		BM25Stats: &search.BM25Stats{
+			DocCount:         float64(count),
+			FieldCardinality: fieldCardinality,
+		},
 	}, nil
 }
 
@@ -558,6 +584,7 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 	var fts search.FieldTermSynonymMap
 	var skipSynonymCollector bool
+	var bm25Data *search.BM25Stats
 	var ok bool
 	if req.PreSearchData != nil {
 		for k, v := range req.PreSearchData {
@@ -578,6 +605,13 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 					}
 					skipSynonymCollector = true
 				}
+			case search.BM25PreSearchDataKey:
+				if v != nil {
+					bm25Data, ok = v.(*search.BM25Stats)
+					if !ok {
+						return nil, fmt.Errorf("bm25 preSearchData must be of type *search.BM25Stats")
+					}
+				}
 			}
 		}
 	}
@@ -605,6 +639,21 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr
 		ctx = context.WithValue(ctx, search.FieldTermSynonymMapKey, fts)
 	}
 
+	scoringModelCallback := func() string {
+		if isBM25Enabled(i.m) {
+			return index.BM25Scoring
+		}
+		return index.DefaultScoringModel
+	}
+	ctx = context.WithValue(ctx, search.GetScoringModelCallbackKey,
+		search.GetScoringModelCallbackFn(scoringModelCallback))
+
+	// set the bm25 preSearch data (stats important for consistent scoring) in
+	// the context object
+	if bm25Data != nil {
+		ctx = context.WithValue(ctx, search.BM25PreSearchDataKey, bm25Data)
+	}
+
 	// This callback and variable handles the tracking of bytes read
 	// 1. as part of creation of tfr and its Next() calls which is
 	// accounted by invoking this callback when the TFR is closed.
@@ -1107,6 +1156,10 @@ func (f *indexImplFieldDict) Close() error {
 	return f.indexReader.Close()
 }
 
+func (f *indexImplFieldDict) Cardinality() int {
+	return f.fieldDict.Cardinality()
+}
+
 // helper function to remove duplicate entries from slice of strings
 func deDuplicate(fields []string) []string {
 	entries := make(map[string]struct{})
diff --git a/index_test.go b/index_test.go
index 82be0d947..c2844584a 100644
--- a/index_test.go
+++ b/index_test.go
@@ -350,6 +350,218 @@ func TestBytesWritten(t *testing.T) {
 	cleanupTmpIndexPath(t, tmpIndexPath4)
 }
 
+func createIndexMappingOnSampleData() *mapping.IndexMappingImpl {
+	indexMapping := NewIndexMapping()
+	indexMapping.TypeField = "type"
+	indexMapping.DefaultAnalyzer = "en"
+	indexMapping.ScoringModel = index.DefaultScoringModel
+	documentMapping := NewDocumentMapping()
+	indexMapping.AddDocumentMapping("hotel", documentMapping)
+	indexMapping.StoreDynamic = false
+	indexMapping.DocValuesDynamic = false
+	contentFieldMapping := NewTextFieldMapping()
+	contentFieldMapping.Store = false
+
+	reviewsMapping := NewDocumentMapping()
+	reviewsMapping.AddFieldMappingsAt("content", contentFieldMapping)
+	documentMapping.AddSubDocumentMapping("reviews", reviewsMapping)
+
+	typeFieldMapping := NewTextFieldMapping()
+	typeFieldMapping.Store = false
+	documentMapping.AddFieldMappingsAt("type", typeFieldMapping)
+
+	return indexMapping
+}
+
+func TestBM25TFIDFScoring(t *testing.T) {
+	tmpIndexPath1 := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPath1)
+	tmpIndexPath2 := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPath2)
+
+	indexMapping := createIndexMappingOnSampleData()
+	indexMapping.ScoringModel = index.BM25Scoring
+	indexBM25, err := NewUsing(tmpIndexPath1, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	indexMapping1 := createIndexMappingOnSampleData()
+	indexTFIDF, err := NewUsing(tmpIndexPath2, indexMapping1, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func() {
+		err := indexBM25.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		err = indexTFIDF.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	batch, err := getBatchFromData(indexBM25, "sample-data.json")
+	if err != nil {
+		t.Fatalf("failed to form a batch")
+	}
+	err = indexBM25.Batch(batch)
+	if err != nil {
+		t.Fatalf("failed to index batch %v\n", err)
+	}
+	query := NewMatchQuery("Hotel")
+	query.FieldVal = "name"
+	searchRequest := NewSearchRequestOptions(query, 10, 0, true)
+
+	resBM25, err := indexBM25.Search(searchRequest)
+	if err != nil {
+		t.Error(err)
+	}
+
+	batch, err = getBatchFromData(indexTFIDF, "sample-data.json")
+	if err != nil {
+		t.Fatalf("failed to form a batch")
+	}
+	err = indexTFIDF.Batch(batch)
+	if err != nil {
+		t.Fatalf("failed to index batch %v\n", err)
+	}
+
+	resTFIDF, err := indexTFIDF.Search(searchRequest)
+	if err != nil {
+		t.Error(err)
+	}
+
+	for i, hit := range resTFIDF.Hits {
+		if hit.Score < resBM25.Hits[i].Score {
+			t.Fatalf("expected the tf-idf score to be no lower than the bm25 score, got bm25=%v and tf-idf=%v",
+				resBM25.Hits[i].Score, hit.Score)
+		}
+	}
+}
+
+func TestBM25GlobalScoring(t *testing.T) {
+	tmpIndexPath := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPath)
+
+	indexMapping := createIndexMappingOnSampleData()
+	indexMapping.ScoringModel = index.BM25Scoring
+	idxSinglePartition, err := NewUsing(tmpIndexPath, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func() {
+		err := idxSinglePartition.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	batch, err := getBatchFromData(idxSinglePartition, "sample-data.json")
+	if err != nil {
+		t.Fatalf("failed to form a batch")
+	}
+	err = idxSinglePartition.Batch(batch)
+	if err != nil {
+		t.Fatalf("failed to index batch %v\n", err)
+	}
+	query := NewMatchQuery("Hotel")
+	query.FieldVal = "name"
+	searchRequest := NewSearchRequestOptions(query, 10, 0, true)
+
+	res, err := idxSinglePartition.Search(searchRequest)
+	if err != nil {
+		t.Error(err)
+	}
+
+	singlePartHits := res.Hits
+
+	dataset, err := readDataFromFile("sample-data.json")
+	if err != nil {
+		t.Fatal(err)
+	}
+	tmpIndexPath1 := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPath1)
+
+	idxPart1, err := NewUsing(tmpIndexPath1, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func() {
+		err := idxPart1.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	batch1 := idxPart1.NewBatch()
+	for _, doc := range dataset[:len(dataset)/2] {
+		err = batch1.Index(fmt.Sprintf("%d", doc["id"]), doc)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	err = idxPart1.Batch(batch1)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	tmpIndexPath2 := createTmpIndexPath(t)
+	defer cleanupTmpIndexPath(t, tmpIndexPath2)
+
+	idxPart2, err := NewUsing(tmpIndexPath2, indexMapping, Config.DefaultIndexType, Config.DefaultMemKVStore, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func() {
+		err := idxPart2.Close()
+		if err != nil {
+			t.Fatal(err)
+		}
+	}()
+
+	batch2 := idxPart2.NewBatch()
+	for _, doc := range dataset[len(dataset)/2:] {
+		err = batch2.Index(fmt.Sprintf("%d", doc["id"]), doc)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+	err = idxPart2.Batch(batch2)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	multiPartIndex := NewIndexAlias(idxPart1, idxPart2)
+	err = multiPartIndex.SetIndexMapping(indexMapping)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	ctx := context.Background()
+	// this key is set to ensure that we have consistent scoring at the index
+	// alias level (it forces a preSearch phase, which can add a small overhead)
+	ctx = context.WithValue(ctx, search.SearchTypeKey, search.GlobalScoring)
+
+	res, err = multiPartIndex.SearchInContext(ctx, searchRequest)
+	if err != nil {
+		t.Error(err)
+	}
+
+	for i, hit := range res.Hits {
+		if hit.Score != singlePartHits[i].Score {
+			t.Fatalf("expected the scores to be the same, got %v and %v",
+				hit.Score, singlePartHits[i].Score)
+		}
+	}
+}
+
 func TestBytesRead(t *testing.T) {
 	tmpIndexPath := createTmpIndexPath(t)
 	defer cleanupTmpIndexPath(t, tmpIndexPath)
@@ -671,23 +883,33 @@ func TestBytesReadStored(t *testing.T) {
 	}
 }
 
-func getBatchFromData(idx Index, fileName string) (*Batch, error) {
+func readDataFromFile(fileName string) ([]map[string]interface{}, error) {
 	pwd, err := os.Getwd()
 	if err != nil {
 		return nil, err
 	}
 	path := filepath.Join(pwd, "data", "test", fileName)
-	batch := idx.NewBatch()
+	var dataset []map[string]interface{}
 	fileContent, err := os.ReadFile(path)
 	if err != nil {
 		return nil, err
 	}
+
 	err = json.Unmarshal(fileContent, &dataset)
 	if err != nil {
 		return nil, err
 	}
+	return dataset, nil
+}
+
+func getBatchFromData(idx Index, fileName string) (*Batch, error) {
+	dataset, err := readDataFromFile(fileName)
+	if err != nil {
+		return nil, err
+	}
+	batch := idx.NewBatch()
 	for _, doc := range dataset {
 		err = batch.Index(fmt.Sprintf("%d", doc["id"]), doc)
 		if err != nil {
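Note for reviewers: the global-scoring flow exercised by TestBM25GlobalScoring above reduces to a few lines of caller code. A minimal sketch, assuming idxPart1/idxPart2 are open indexes built with a BM25-enabled mapping as in the test (the helper name globalSearch is hypothetical):

package main

import (
	"context"
	"fmt"

	bleve "github.com/blevesearch/bleve/v2"
	"github.com/blevesearch/bleve/v2/mapping"
	"github.com/blevesearch/bleve/v2/search"
)

// globalSearch queries an alias over sharded indexes with globally
// consistent BM25 scoring. idxPart1/idxPart2 are assumed to be open
// indexes built with the BM25-enabled mapping m (as in the test above).
func globalSearch(m mapping.IndexMapping, idxPart1, idxPart2 bleve.Index) error {
	alias := bleve.NewIndexAlias(idxPart1, idxPart2)
	if err := alias.SetIndexMapping(m); err != nil {
		return err
	}

	// GlobalScoring triggers a preSearch that merges DocCount and the
	// per-field cardinality across the partitions before scoring.
	ctx := context.WithValue(context.Background(),
		search.SearchTypeKey, search.GlobalScoring)

	req := bleve.NewSearchRequest(bleve.NewMatchQuery("Hotel"))
	res, err := alias.SearchInContext(ctx, req)
	if err != nil {
		return err
	}
	fmt.Println(res.Hits)
	return nil
}

Without the context key, each partition scores hits using only its local stats, which is what the test compares against.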
diff --git a/mapping/field.go b/mapping/field.go
index ce2878b18..cfb390b40 100644
--- a/mapping/field.go
+++ b/mapping/field.go
@@ -74,8 +74,8 @@ type FieldMapping struct {
 	Dims int `json:"dims,omitempty"`
 
 	// Similarity is the similarity algorithm used for scoring
-	// vector fields.
-	// See: index.DefaultSimilarityMetric & index.SupportedSimilarityMetrics
+	// the field's content while performing a search on it.
+	// See: index.SimilarityModels
 	Similarity string `json:"similarity,omitempty"`
 
 	// Applicable to vector fields only - optimization string
diff --git a/mapping/index.go b/mapping/index.go
index 8a0d5e34a..6150f2a38 100644
--- a/mapping/index.go
+++ b/mapping/index.go
@@ -50,6 +50,7 @@ type IndexMappingImpl struct {
 	DefaultAnalyzer       string `json:"default_analyzer"`
 	DefaultDateTimeParser string `json:"default_datetime_parser"`
 	DefaultSynonymSource  string `json:"default_synonym_source,omitempty"`
+	ScoringModel          string `json:"scoring_model,omitempty"`
 	DefaultField          string `json:"default_field"`
 	StoreDynamic          bool   `json:"store_dynamic"`
 	IndexDynamic          bool   `json:"index_dynamic"`
@@ -201,6 +202,11 @@ func (im *IndexMappingImpl) Validate() error {
 			return err
 		}
 	}
+
+	if _, ok := index.SupportedScoringModels[im.ScoringModel]; !ok && im.ScoringModel != "" {
+		return fmt.Errorf("unsupported scoring model: %s", im.ScoringModel)
+	}
+
 	return nil
 }
 
@@ -303,6 +309,12 @@ func (im *IndexMappingImpl) UnmarshalJSON(data []byte) error {
 			if err != nil {
 				return err
 			}
+		case "scoring_model":
+			err := util.UnmarshalJSON(v, &im.ScoringModel)
+			if err != nil {
+				return err
+			}
+
 		default:
 			invalidKeys = append(invalidKeys, k)
 		}
diff --git a/mapping/mapping_vectors.go b/mapping/mapping_vectors.go
index dbfde1fb0..20cbac6a8 100644
--- a/mapping/mapping_vectors.go
+++ b/mapping/mapping_vectors.go
@@ -204,7 +204,7 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string,
 	}
 
 	if field.Similarity == "" {
-		field.Similarity = index.DefaultSimilarityMetric
+		field.Similarity = index.DefaultVectorSimilarityMetric
 	}
 
 	if field.VectorIndexOptimizedFor == "" {
@@ -249,10 +249,10 @@ func validateVectorFieldAlias(field *FieldMapping, parentName string,
 			MinVectorDims, MaxVectorDims)
 	}
 
-	if _, ok := index.SupportedSimilarityMetrics[field.Similarity]; !ok {
+	if _, ok := index.SupportedVectorSimilarityMetrics[field.Similarity]; !ok {
 		return fmt.Errorf("field: '%s', invalid similarity "+
 			"metric: '%s', valid metrics are: %+v", field.Name, field.Similarity,
-			reflect.ValueOf(index.SupportedSimilarityMetrics).MapKeys())
+			reflect.ValueOf(index.SupportedVectorSimilarityMetrics).MapKeys())
 	}
 
 	if fieldAliasCtx != nil { // writing to a nil map is unsafe
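The mapping-level ScoringModel field above is the only user-facing knob for this feature. A minimal sketch of opting an index into BM25 (the helper name newBM25Index is hypothetical; index.BM25Scoring and the "scoring_model" JSON key come from the changes above):

package main

import (
	bleve "github.com/blevesearch/bleve/v2"
	index "github.com/blevesearch/bleve_index_api"
)

// newBM25Index creates an index whose term scoring uses BM25 instead of
// the default tf-idf. Validate rejects any value not present in
// index.SupportedScoringModels.
func newBM25Index(path string) (bleve.Index, error) {
	m := bleve.NewIndexMapping()
	m.ScoringModel = index.BM25Scoring // serialized as "scoring_model"
	if err := m.Validate(); err != nil {
		return nil, err
	}
	return bleve.New(path, m)
}

Since the field is part of the mapping JSON, an existing mapping document can equally opt in with "scoring_model": "bm25"-style configuration (the exact constant value is whatever index.BM25Scoring defines).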
diff --git a/pre_search.go b/pre_search.go
index 5fd710d68..3dd7e0fe3 100644
--- a/pre_search.go
+++ b/pre_search.go
@@ -82,6 +82,34 @@ func (s *synonymPreSearchResultProcessor) finalize(sr *SearchResult) {
 	}
 }
 
+type bm25PreSearchResultProcessor struct {
+	docCount         float64 // bm25 specific stats
+	fieldCardinality map[string]int
+}
+
+func newBM25PreSearchResultProcessor() *bm25PreSearchResultProcessor {
+	return &bm25PreSearchResultProcessor{
+		fieldCardinality: make(map[string]int),
+	}
+}
+
+// TODO How will this work for queries other than term queries?
+func (b *bm25PreSearchResultProcessor) add(sr *SearchResult, indexName string) {
+	if sr.BM25Stats != nil {
+		b.docCount += sr.BM25Stats.DocCount
+		for field, cardinality := range sr.BM25Stats.FieldCardinality {
+			b.fieldCardinality[field] += cardinality
+		}
+	}
+}
+
+func (b *bm25PreSearchResultProcessor) finalize(sr *SearchResult) {
+	sr.BM25Stats = &search.BM25Stats{
+		DocCount:         b.docCount,
+		FieldCardinality: b.fieldCardinality,
+	}
+}
+
 // -----------------------------------------------------------------------------
 // Master struct that can hold any number of presearch result processors
 type compositePreSearchResultProcessor struct {
@@ -122,6 +150,11 @@ func createPreSearchResultProcessor(req *SearchRequest, flags *preSearchFlags) p
 			processors = append(processors, synonymProcessor)
 		}
 	}
	if flags.bm25 {
+		if bm25Processor := newBM25PreSearchResultProcessor(); bm25Processor != nil {
+			processors = append(processors, bm25Processor)
+		}
+	}
 	// Return based on the number of processors, optimizing for the common case of 1 processor
 	// If there are no processors, return nil
 	switch len(processors) {
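To make the merge semantics above concrete: DocCount and the per-field cardinalities are simply summed across partitions, so a term present in several partitions' dictionaries is counted once per partition (the TODO above hints this is still an approximation). A within-package sketch with assumed values, e.g. inside a bleve package test:

// Within-package illustration (assumed values): merging the BM25
// preSearch stats of two partitions.
p := newBM25PreSearchResultProcessor()
p.add(&SearchResult{BM25Stats: &search.BM25Stats{
	DocCount:         100,
	FieldCardinality: map[string]int{"name": 40},
}}, "partition-1")
p.add(&SearchResult{BM25Stats: &search.BM25Stats{
	DocCount:         50,
	FieldCardinality: map[string]int{"name": 25},
}}, "partition-2")

var merged SearchResult
p.finalize(&merged)
// merged.BM25Stats.DocCount == 150
// merged.BM25Stats.FieldCardinality["name"] == 65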
diff --git a/search.go b/search.go
index 72bfca5e2..e13a93703 100644
--- a/search.go
+++ b/search.go
@@ -447,6 +447,9 @@ type SearchResult struct {
 	// special fields that are applicable only for search
 	// results that are obtained from a presearch
 	SynonymResult search.FieldTermSynonymMap `json:"synonym_result,omitempty"`
+
+	// BM25Stats is populated only in results obtained from a BM25 preSearch
+	BM25Stats *search.BM25Stats `json:"bm25_stats,omitempty"`
 }
 
 func (sr *SearchResult) Size() int {
diff --git a/search/query/knn.go b/search/query/knn.go
index 4d105d943..8221fbcea 100644
--- a/search/query/knn.go
+++ b/search/query/knn.go
@@ -82,7 +82,7 @@ func (q *KNNQuery) Searcher(ctx context.Context, i index.IndexReader,
 	fieldMapping := m.FieldMappingForPath(q.VectorField)
 	similarityMetric := fieldMapping.Similarity
 	if similarityMetric == "" {
-		similarityMetric = index.DefaultSimilarityMetric
+		similarityMetric = index.DefaultVectorSimilarityMetric
 	}
 	if q.K <= 0 || len(q.Vector) == 0 {
 		return nil, fmt.Errorf("k must be greater than 0 and vector must be non-empty")
diff --git a/search/query/query.go b/search/query/query.go
index 86859ae5b..a1f7b3404 100644
--- a/search/query/query.go
+++ b/search/query/query.go
@@ -105,6 +105,19 @@ func ParsePreSearchData(input []byte) (map[string]interface{}, error) {
 				rv = make(map[string]interface{})
 			}
 			rv[search.SynonymPreSearchDataKey] = value
+		case search.BM25PreSearchDataKey:
+			var value *search.BM25Stats
+			if v != nil {
+				err := util.UnmarshalJSON(v, &value)
+				if err != nil {
+					return nil, err
+				}
+			}
+			if rv == nil {
+				rv = make(map[string]interface{})
+			}
+			rv[search.BM25PreSearchDataKey] = value
+
 		}
 	}
 	return rv, nil
diff --git a/search/scorer/scorer_term.go b/search/scorer/scorer_term.go
index 7b60eda4e..f5f8ec935 100644
--- a/search/scorer/scorer_term.go
+++ b/search/scorer/scorer_term.go
@@ -35,8 +35,9 @@ type TermQueryScorer struct {
 	queryTerm    string
 	queryField   string
 	queryBoost   float64
-	docTerm      uint64
-	docTotal     uint64
+	docTerm      uint64 // number of documents containing the term
+	docTotal     uint64 // total number of documents in the index
+	avgDocLength float64
 	idf          float64
 	options      search.SearcherOptions
 	idfExplanation *search.Explanation
@@ -61,19 +62,43 @@ func (s *TermQueryScorer) Size() int {
 	return sizeInBytes
 }
 
-func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal, docTerm uint64, options search.SearcherOptions) *TermQueryScorer {
+func (s *TermQueryScorer) computeIDF(avgDocLength float64, docTotal, docTerm uint64) float64 {
+	var rv float64
+	if avgDocLength > 0 {
+		// avgDocLength is set only for bm25 scoring
+		rv = math.Log(1 + (float64(docTotal)-float64(docTerm)+0.5)/
+			(float64(docTerm)+0.5))
+	} else {
+		rv = 1.0 + math.Log(float64(docTotal)/
+			float64(docTerm+1.0))
+	}
+
+	return rv
+}
+
+// queryTerm - the specific term being scored by this scorer object
+// queryField - the field in which the term is being searched
+// queryBoost - the boost value for the query term
+// docTotal - total number of documents in the index
+// docTerm - number of documents containing the term
+// avgDocLength - average document length in the index
+// options - search options such as explain scoring, include the location of the term etc.
+func NewTermQueryScorer(queryTerm []byte, queryField string, queryBoost float64, docTotal,
+	docTerm uint64, avgDocLength float64, options search.SearcherOptions) *TermQueryScorer {
+
 	rv := TermQueryScorer{
 		queryTerm:    string(queryTerm),
 		queryField:   queryField,
 		queryBoost:   queryBoost,
 		docTerm:      docTerm,
 		docTotal:     docTotal,
-		idf:          1.0 + math.Log(float64(docTotal)/float64(docTerm+1.0)),
+		avgDocLength: avgDocLength,
 		options:      options,
 		queryWeight:  1.0,
 		includeScore: options.Score != "none",
 	}
 
+	rv.idf = rv.computeIDF(avgDocLength, docTotal, docTerm)
 	if options.Explain {
 		rv.idfExplanation = &search.Explanation{
 			Value: rv.idf,
@@ -114,6 +139,63 @@ func (s *TermQueryScorer) SetQueryNorm(qnorm float64) {
 	}
 }
 
+func (s *TermQueryScorer) docScore(tf, norm float64) (score float64, model string) {
+	if s.avgDocLength > 0 {
+		// bm25 scoring: the posting's norm encodes 1/sqrt(fieldLength), so
+		// invert it to recover the field length for this doc num
+		fieldLength := 1 / (norm * norm)
+
+		score = s.idf * (tf * search.BM25_k1) /
+			(tf + search.BM25_k1*(1-search.BM25_b+(search.BM25_b*fieldLength/s.avgDocLength)))
+		model = index.BM25Scoring
+	} else {
+		// tf-idf scoring by default
+		score = tf * norm * s.idf
+		model = index.DefaultScoringModel
+	}
+	return score, model
+}
+
+func (s *TermQueryScorer) scoreExplanation(tf float64, termMatch *index.TermFieldDoc) []*search.Explanation {
+	var rv []*search.Explanation
+	if s.avgDocLength > 0 {
+		fieldLength := 1 / (termMatch.Norm * termMatch.Norm)
+		fieldNormVal := 1 - search.BM25_b + (search.BM25_b * fieldLength / s.avgDocLength)
+		fieldNormalizeExplanation := &search.Explanation{
+			Value: fieldNormVal,
+			Message: fmt.Sprintf("fieldNorm(field=%s, b=%f, fieldLength=%f, avgFieldLength=%f)",
+				s.queryField, search.BM25_b, fieldLength, s.avgDocLength),
+		}
+
+		saturationExplanation := &search.Explanation{
+			Value: search.BM25_k1 / (tf + search.BM25_k1*fieldNormVal),
+			Message: fmt.Sprintf("saturation(term:%s, k1=%f/(tf=%f + k1*fieldNorm=%f))",
+				termMatch.Term, search.BM25_k1, tf, fieldNormVal),
+			Children: []*search.Explanation{fieldNormalizeExplanation},
+		}
+
+		rv = make([]*search.Explanation, 3)
+		rv[0] = &search.Explanation{
+			Value:   tf,
+			Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq),
+		}
+		rv[1] = saturationExplanation
+		rv[2] = s.idfExplanation
+	} else {
+		rv = make([]*search.Explanation, 3)
+		rv[0] = &search.Explanation{
+			Value:   tf,
+			Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq),
+		}
+		rv[1] = &search.Explanation{
+			Value:   termMatch.Norm,
+			Message: fmt.Sprintf("fieldNorm(field=%s, doc=%s)", s.queryField, termMatch.ID),
+		}
+		rv[2] = s.idfExplanation
+	}
+	return rv
+}
+
 func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.TermFieldDoc) *search.DocumentMatch {
 	rv := ctx.DocumentMatchPool.Get()
 	// perform any score computations only when needed
@@ -125,22 +207,14 @@ func (s *TermQueryScorer) Score(ctx *search.SearchContext, termMatch *index.Term
 		} else {
 			tf = math.Sqrt(float64(termMatch.Freq))
 		}
-		score := tf * termMatch.Norm * s.idf
+		score, scoringModel := s.docScore(tf, termMatch.Norm)
 		if s.options.Explain {
-			childrenExplanations := make([]*search.Explanation, 3)
-			childrenExplanations[0] = &search.Explanation{
-				Value:   tf,
-				Message: fmt.Sprintf("tf(termFreq(%s:%s)=%d", s.queryField, s.queryTerm, termMatch.Freq),
-			}
-			childrenExplanations[1] = &search.Explanation{
-				Value:   termMatch.Norm,
-				Message: fmt.Sprintf("fieldNorm(field=%s, doc=%s)", s.queryField, termMatch.ID),
-			}
-			childrenExplanations[2] = s.idfExplanation
+			childrenExplanations := s.scoreExplanation(tf, termMatch)
 			scoreExplanation = &search.Explanation{
-				Value:   score,
-				Message: fmt.Sprintf("fieldWeight(%s:%s in %s), product of:", s.queryField, s.queryTerm, termMatch.ID),
+				Value: score,
+				Message: fmt.Sprintf("fieldWeight(%s:%s in %s), as per %s model, "+
+					"product of:", s.queryField, s.queryTerm, termMatch.ID, scoringModel),
 				Children: childrenExplanations,
 			}
 		}
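For a feel of how the two computeIDF branches above differ, a standalone reproduction with assumed counts (docTotal=100 documents, docTerm=10 of them containing the term):

package main

import (
	"fmt"
	"math"
)

// Standalone reproduction of TermQueryScorer.computeIDF's two branches,
// on assumed inputs.
func main() {
	docTotal, docTerm := 100.0, 10.0

	// tf-idf branch (taken when avgDocLength == 0):
	tfidfIDF := 1.0 + math.Log(docTotal/(docTerm+1.0))
	// bm25 branch (taken when avgDocLength > 0):
	bm25IDF := math.Log(1 + (docTotal-docTerm+0.5)/(docTerm+0.5))

	fmt.Printf("tf-idf idf=%.4f, bm25 idf=%.4f\n", tfidfIDF, bm25IDF)
	// tf-idf idf ≈ 3.2073, bm25 idf ≈ 2.2638: the BM25 variant grows more
	// slowly with rarity, and the +1 inside the log keeps it positive for
	// any docTerm <= docTotal.
}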
"fieldWeight(desc:beer in one), as per tfidf model, product of:", Children: []*search.Explanation{ { Value: 1, diff --git a/search/searcher/search_disjunction.go b/search/searcher/search_disjunction.go index d165ec027..434c705e7 100644 --- a/search/searcher/search_disjunction.go +++ b/search/searcher/search_disjunction.go @@ -114,7 +114,7 @@ func optimizeCompositeSearcher(ctx context.Context, optimizationKind string, return nil, nil } - return newTermSearcherFromReader(indexReader, tfr, + return newTermSearcherFromReader(ctx, indexReader, tfr, []byte(optimizationKind), "*", 1.0, options) } diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index c519d8d51..1c33c6a41 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -16,6 +16,8 @@ package searcher import ( "context" + "fmt" + "math" "reflect" "github.com/blevesearch/bleve/v2/search" @@ -38,14 +40,16 @@ type TermSearcher struct { tfd index.TermFieldDoc } -func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { +func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, + term string, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { if isTermQuery(ctx) { ctx = context.WithValue(ctx, search.QueryTypeKey, search.Term) } return NewTermSearcherBytes(ctx, indexReader, []byte(term), field, boost, options) } -func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, term []byte, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { +func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, + term []byte, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { if ctx != nil { if fts, ok := ctx.Value(search.FieldTermSynonymMapKey).(search.FieldTermSynonymMap); ok { if ts, exists := fts[field]; exists { @@ -60,17 +64,85 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, te if err != nil { return nil, err } - return newTermSearcherFromReader(indexReader, reader, term, field, boost, options) + return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boost, options) } -func newTermSearcherFromReader(indexReader index.IndexReader, reader index.TermFieldReader, - term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { +func tfIDFScoreMetrics(indexReader index.IndexReader) (uint64, error) { + // default tf-idf stats count, err := indexReader.DocCount() if err != nil { - _ = reader.Close() - return nil, err + return 0, err + } + + if count == 0 { + return 0, nil + } + return count, nil +} + +func bm25ScoreMetrics(ctx context.Context, field string, + indexReader index.IndexReader) (uint64, float64, error) { + var count uint64 + var fieldCardinality int + var err error + + bm25Stats, ok := ctx.Value(search.BM25PreSearchDataKey).(*search.BM25Stats) + if !ok { + count, err = indexReader.DocCount() + if err != nil { + return 0, 0, err + } + dict, err := indexReader.FieldDict(field) + if err != nil { + return 0, 0, err + } + fieldCardinality = dict.Cardinality() + } else { + count = uint64(bm25Stats.DocCount) + fieldCardinality, ok = bm25Stats.FieldCardinality[field] + if !ok { + return 0, 0, fmt.Errorf("field stat for bm25 not present %s", field) + } + } + + if count == 0 && fieldCardinality == 0 { + return 0, 0, nil + } + return count, 
math.Ceil(float64(fieldCardinality) / float64(count)), nil +} + +func newTermSearcherFromReader(ctx context.Context, indexReader index.IndexReader, + reader index.TermFieldReader, term []byte, field string, boost float64, + options search.SearcherOptions) (*TermSearcher, error) { + var count uint64 + var avgDocLength float64 + var err error + var similarityModel string + + // as a fallback case we track certain stats for tf-idf scoring + if ctx != nil { + if similaritModelCallback, ok := ctx.Value(search. + GetScoringModelCallbackKey).(search.GetScoringModelCallbackFn); ok { + similarityModel = similaritModelCallback() + } + } + switch similarityModel { + case index.BM25Scoring: + count, avgDocLength, err = bm25ScoreMetrics(ctx, field, indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } + case index.TFIDFScoring: + fallthrough + default: + count, err = tfIDFScoreMetrics(indexReader) + if err != nil { + _ = reader.Close() + return nil, err + } } - scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), options) + scorer := scorer.NewTermQueryScorer(term, field, boost, count, reader.Count(), avgDocLength, options) return &TermSearcher{ indexReader: indexReader, reader: reader, @@ -85,7 +157,7 @@ func NewSynonymSearcher(ctx context.Context, indexReader index.IndexReader, term if err != nil { return nil, err } - return newTermSearcherFromReader(indexReader, reader, term, field, boostVal, options) + return newTermSearcherFromReader(ctx, indexReader, reader, term, field, boostVal, options) } // create a searcher for the term itself termSearcher, err := createTermSearcher(term, boost) diff --git a/search/util.go b/search/util.go index 2e95f1180..0530c6732 100644 --- a/search/util.go +++ b/search/util.go @@ -135,16 +135,47 @@ const MinGeoBufPoolSize = 24 type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool +// PreSearchKey indicates whether to perform a preliminary search to gather necessary +// information which would be used in the actual search down the line. +const PreSearchKey = "_presearch_key" + +// *PreSearchDataKey are used to store the data gathered during the presearch phase +// which would be use in the actual search phase. const KnnPreSearchDataKey = "_knn_pre_search_data_key" const SynonymPreSearchDataKey = "_synonym_pre_search_data_key" +const BM25PreSearchDataKey = "_bm25_pre_search_data_key" -const PreSearchKey = "_presearch_key" +// SearchTypeKey is used to identify type of the search being performed. +// +// for consistent scoring in cases an index is partitioned/sharded (using an +// index alias), GlobalScoring helps in aggregating the necessary stats across +// all the child bleve indexes (shards/partitions) first before the actual search +// is performed, such that the scoring involved using these stats would be at a +// global level. +const SearchTypeKey = "_search_type_key" + +// The following keys are used to invoke the callbacks at the start and end stages +// of optimizing the disjunction/conjunction searcher creation. +const SearcherStartCallbackKey = "_searcher_start_callback_key" +const SearcherEndCallbackKey = "_searcher_end_callback_key" -type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) +// FieldTermSynonymMapKey is used to store and transport the synonym definitions data +// to the actual search phase which would use the synonyms to perform the search. 
diff --git a/search/util.go b/search/util.go
index 2e95f1180..0530c6732 100644
--- a/search/util.go
+++ b/search/util.go
@@ -135,16 +135,47 @@ const MinGeoBufPoolSize = 24
 
 type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool
 
+// PreSearchKey indicates whether to perform a preliminary search to gather
+// information which would be used in the actual search down the line.
+const PreSearchKey = "_presearch_key"
+
+// The *PreSearchDataKey constants are used to store the data gathered during
+// the preSearch phase, which would be used in the actual search phase.
 const KnnPreSearchDataKey = "_knn_pre_search_data_key"
 const SynonymPreSearchDataKey = "_synonym_pre_search_data_key"
+const BM25PreSearchDataKey = "_bm25_pre_search_data_key"
 
-const PreSearchKey = "_presearch_key"
+// SearchTypeKey is used to identify the type of search being performed.
+//
+// For consistent scoring when an index is partitioned/sharded (using an
+// index alias), GlobalScoring aggregates the necessary stats across all
+// the child bleve indexes (shards/partitions) before the actual search is
+// performed, so that the scoring done with these stats is effectively at
+// a global level.
+const SearchTypeKey = "_search_type_key"
+
+// The following keys are used to invoke the callbacks at the start and end stages
+// of optimizing the disjunction/conjunction searcher creation.
+const SearcherStartCallbackKey = "_searcher_start_callback_key"
+const SearcherEndCallbackKey = "_searcher_end_callback_key"
 
-type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)
+// FieldTermSynonymMapKey is used to store and transport the synonym definitions
+// to the actual search phase, which uses the synonyms to perform the search.
+const FieldTermSynonymMapKey = "_field_term_synonym_map_key"
+
+const GlobalScoring = "_global_scoring"
+
+// GetScoringModelCallbackKey is used to help the underlying searcher identify
+// which scoring mechanism to use, based on the index mapping.
+const GetScoringModelCallbackKey = "_get_scoring_model"
 
 type SearcherStartCallbackFn func(size uint64) error
 type SearcherEndCallbackFn func(size uint64) error
 
+type GetScoringModelCallbackFn func() string
+
+type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation)
+
 // field -> term -> synonyms
 type FieldTermSynonymMap map[string]map[string][]string
 
@@ -161,7 +192,17 @@ func (f FieldTermSynonymMap) MergeWith(fts FieldTermSynonymMap) {
 		}
 	}
 }
 
-const FieldTermSynonymMapKey = "_field_term_synonym_map_key"
-
-const SearcherStartCallbackKey = "_searcher_start_callback_key"
-const SearcherEndCallbackKey = "_searcher_end_callback_key"
+// BM25 specific multipliers which control the scoring of a document.
+//
+// BM25_b - controls the extent to which a document's field length normalizes the term frequency part of the score
+// BM25_k1 - controls the saturation of the score due to term frequency
+// The default values follow Elasticsearch's implementation:
+//   - https://www.elastic.co/guide/en/elasticsearch/reference/current/index-modules-similarity.html#bm25
+//   - https://www.elastic.co/blog/practical-bm25-part-3-considerations-for-picking-b-and-k1-in-elasticsearch
+var BM25_k1 float64 = 1.2
+var BM25_b float64 = 0.75
+
+type BM25Stats struct {
+	DocCount         float64        `json:"doc_count"`
+	FieldCardinality map[string]int `json:"field_cardinality"`
+}
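Putting the constants above together with docScore's BM25 branch from scorer_term.go, a standalone reproduction on assumed inputs. The field length is recovered from the posting norm (which encodes 1/sqrt(fieldLength), per the comment in docScore), tf is sqrt(freq) as in Score, and avgDocLength stands in for the ceil(cardinality/docCount) estimate:

package main

import (
	"fmt"
	"math"
)

const (
	k1 = 1.2  // search.BM25_k1
	b  = 0.75 // search.BM25_b
)

// Standalone reproduction of TermQueryScorer.docScore's BM25 branch,
// with assumed example inputs.
func main() {
	idf := 2.2638             // e.g. docTotal=100, docTerm=10 (see computeIDF)
	tf := math.Sqrt(3)        // term occurs 3 times in the field
	norm := 1 / math.Sqrt(12) // posting norm for a field holding 12 terms
	avgDocLength := 10.0      // assumed ceil(cardinality/docCount) estimate

	fieldLength := 1 / (norm * norm) // recovers 12
	score := idf * (tf * k1) /
		(tf + k1*(1-b+(b*fieldLength/avgDocLength)))
	fmt.Printf("bm25 score = %.4f\n", score)
}

Longer-than-average fields inflate the denominator (via b), and repeated occurrences of the term saturate rather than grow linearly (via k1), which is the behavioral difference from the tf*norm*idf product used by the default tf-idf path.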