diff --git a/index/scorch/snapshot_index.go b/index/scorch/snapshot_index.go index 79840a41f..685f1c921 100644 --- a/index/scorch/snapshot_index.go +++ b/index/scorch/snapshot_index.go @@ -60,6 +60,7 @@ var reflectStaticSizeIndexSnapshot int // exported variable, or at the index level by setting the FieldTFRCacheThreshold // in the kvConfig. var DefaultFieldTFRCacheThreshold uint64 = 10 +var DefaultSynonymTermReaderCacheThreshold uint64 = 10 func init() { var is interface{} = IndexSnapshot{} @@ -87,8 +88,9 @@ type IndexSnapshot struct { m sync.Mutex // Protects the fields that follow. refs int64 - m2 sync.Mutex // Protects the fields that follow. - fieldTFRs map[string][]*IndexSnapshotTermFieldReader // keyed by field, recycled TFR's + m2 sync.Mutex // Protects the fields that follow. + fieldTFRs map[string][]*IndexSnapshotTermFieldReader // keyed by field, recycled TFR's + synonymTermReaders map[string][]*IndexSnapshotSynonymTermReader // keyed by thesaurus name, recycled thesaurus readers } func (i *IndexSnapshot) Segments() []*SegmentSnapshot { @@ -649,6 +651,15 @@ func (is *IndexSnapshot) getFieldTFRCacheThreshold() uint64 { return DefaultFieldTFRCacheThreshold } +func (is *IndexSnapshot) getSynonymTermReaderCacheThreshold() uint64 { + if is.parent.config != nil { + if _, ok := is.parent.config["SynonymTermReaderCacheThreshold"]; ok { + return is.parent.config["SynonymTermReaderCacheThreshold"].(uint64) + } + } + return DefaultSynonymTermReaderCacheThreshold +} + func (is *IndexSnapshot) recycleTermFieldReader(tfr *IndexSnapshotTermFieldReader) { if !tfr.recycle { // Do not recycle an optimized unadorned term field reader (used for @@ -677,6 +688,25 @@ func (is *IndexSnapshot) recycleTermFieldReader(tfr *IndexSnapshotTermFieldReade is.m2.Unlock() } +func (is *IndexSnapshot) recycleSynonymTermReader(str *IndexSnapshotSynonymTermReader) { + is.parent.rootLock.RLock() + obsolete := is.parent.root != is + is.parent.rootLock.RUnlock() + if obsolete { + // if we're not the current root (mutations happened), don't bother recycling + return + } + + is.m2.Lock() + if is.synonymTermReaders == nil { + is.synonymTermReaders = map[string][]*IndexSnapshotSynonymTermReader{} + } + if uint64(len(is.synonymTermReaders[str.name])) < is.getSynonymTermReaderCacheThreshold() { + is.synonymTermReaders[str.name] = append(is.synonymTermReaders[str.name], str) + } + is.m2.Unlock() +} + func docNumberToBytes(buf []byte, in uint64) []byte { if len(buf) != 8 { if cap(buf) >= 8 { @@ -956,3 +986,60 @@ func (is *IndexSnapshot) CloseCopyReader() error { // close the index snapshot normally return is.Close() } + +func (is *IndexSnapshot) allocSynonymTermReader(name string) (str *IndexSnapshotSynonymTermReader) { + is.m2.Lock() + if is.synonymTermReaders != nil { + strs := is.synonymTermReaders[name] + last := len(strs) - 1 + if last >= 0 { + str = strs[last] + strs[last] = nil + is.synonymTermReaders[name] = strs[:last] + is.m2.Unlock() + return + } + } + is.m2.Unlock() + return &IndexSnapshotSynonymTermReader{} +} + +func (is *IndexSnapshot) SynonymTermReader(ctx context.Context, thesaurusName string, term []byte) (index.SynonymTermReader, error) { + rv := is.allocSynonymTermReader(thesaurusName) + + rv.name = thesaurusName + rv.snapshot = is + if rv.postings == nil { + rv.postings = make([]segment.SynonymsList, len(is.segment)) + } + if rv.iterators == nil { + rv.iterators = make([]segment.SynonymsIterator, len(is.segment)) + } + rv.segmentOffset = 0 + + if rv.thesauri == nil { + rv.thesauri = make([]segment.Thesaurus, len(is.segment)) + for i, s := range is.segment { + if synSeg, ok := s.segment.(segment.SynonymSegment); ok { + thes, err := synSeg.Thesaurus(thesaurusName) + if err != nil { + return nil, err + } + rv.thesauri[i] = thes + } + } + } + + for i, s := range is.segment { + if _, ok := s.segment.(segment.SynonymSegment); ok { + pl, err := rv.thesauri[i].SynonymsList(term, s.deleted, rv.postings[i]) + if err != nil { + return nil, err + } + rv.postings[i] = pl + + rv.iterators[i] = pl.Iterator(rv.iterators[i]) + } + } + return rv, nil +} diff --git a/index/scorch/snapshot_index_str.go b/index/scorch/snapshot_index_str.go new file mode 100644 index 000000000..5dee83770 --- /dev/null +++ b/index/scorch/snapshot_index_str.go @@ -0,0 +1,82 @@ +// Copyright (c) 2024 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package scorch + +import ( + "reflect" + + "github.com/blevesearch/bleve/v2/size" + segment "github.com/blevesearch/scorch_segment_api/v2" +) + +var reflectStaticSizeIndexSnapshotSynonymTermReader int + +func init() { + var istr IndexSnapshotSynonymTermReader + reflectStaticSizeIndexSnapshotSynonymTermReader = int(reflect.TypeOf(istr).Size()) +} + +type IndexSnapshotSynonymTermReader struct { + name string + snapshot *IndexSnapshot + thesauri []segment.Thesaurus + postings []segment.SynonymsList + iterators []segment.SynonymsIterator + segmentOffset int +} + +func (i *IndexSnapshotSynonymTermReader) Size() int { + sizeInBytes := reflectStaticSizeIndexSnapshotSynonymTermReader + size.SizeOfPtr + + len(i.name) + + for _, thesaurus := range i.thesauri { + sizeInBytes += thesaurus.Size() + } + + for _, postings := range i.postings { + sizeInBytes += postings.Size() + } + + for _, iterator := range i.iterators { + sizeInBytes += iterator.Size() + } + + return sizeInBytes +} + +func (i *IndexSnapshotSynonymTermReader) Next() (string, error) { + // find the next hit + for i.segmentOffset < len(i.iterators) { + if i.iterators[i.segmentOffset] != nil { + next, err := i.iterators[i.segmentOffset].Next() + if err != nil { + return "", err + } + if next != nil { + synTerm := next.Term() + return synTerm, nil + } + i.segmentOffset++ + } + } + return "", nil +} + +func (i *IndexSnapshotSynonymTermReader) Close() error { + if i.snapshot != nil { + i.snapshot.recycleSynonymTermReader(i) + } + return nil +} diff --git a/index_impl.go b/index_impl.go index e6debf17a..f718e2cde 100644 --- a/index_impl.go +++ b/index_impl.go @@ -505,8 +505,10 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr } var knnHits []*search.DocumentMatch + var thesauri map[string]map[string][]string var ok bool var skipKnnCollector bool + var skipSynonymCollector bool if req.PreSearchData != nil { for k, v := range req.PreSearchData { switch k { @@ -516,8 +518,16 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr if !ok { return nil, fmt.Errorf("knn preSearchData must be of type []*search.DocumentMatch") } + skipKnnCollector = true + } + case search.SynonymPreSearchDataKey: + if v != nil { + thesauri, ok = v.(search.FieldTermSynonyms) + if !ok { + return nil, fmt.Errorf("synonym preSearchData must be of type map[string]map[string][]string") + } + skipSynonymCollector = true } - skipKnnCollector = true } } } @@ -528,6 +538,10 @@ func (i *indexImpl) SearchInContext(ctx context.Context, req *SearchRequest) (sr } } + if !skipSynonymCollector && mappingHasSynonymSources(i.m) { + terms := getTermsFromQuery(req.Query) + } + setKnnHitsInCollector(knnHits, req, coll) // This callback and variable handles the tracking of bytes read @@ -1127,3 +1141,10 @@ func (i *indexImpl) FireIndexEvent() { internalEventIndex.FireIndexEvent() } } + +func mappingHasSynonymSources(m mapping.IndexMapping) bool { + if im, ok := m.(*mapping.IndexMappingImpl); ok { + return len(im.SynonymSources) > 0 + } + return false +} diff --git a/mapping/index.go b/mapping/index.go index e7387d58a..94b2cdfa7 100644 --- a/mapping/index.go +++ b/mapping/index.go @@ -54,6 +54,7 @@ type IndexMappingImpl struct { IndexDynamic bool `json:"index_dynamic"` DocValuesDynamic bool `json:"docvalues_dynamic"` CustomAnalysis *customAnalysis `json:"analysis,omitempty"` + SynonymSources map[string]*SynonymSource `json:"synonym_sources,omitempty"` cache *registry.Cache } @@ -186,6 +187,12 @@ func (im *IndexMappingImpl) Validate() error { return err } } + for _, synSource := range im.SynonymSources { + err = synSource.Validate(im.cache) + if err != nil { + return err + } + } return nil } @@ -283,6 +290,14 @@ func (im *IndexMappingImpl) UnmarshalJSON(data []byte) error { if err != nil { return err } + case "synonym_sources": + if im.SynonymSources == nil { + im.SynonymSources = make(map[string]*SynonymSource) + } + err := util.UnmarshalJSON(v, &im.SynonymSources) + if err != nil { + return err + } default: invalidKeys = append(invalidKeys, k) } diff --git a/mapping/synonym.go b/mapping/synonym.go new file mode 100644 index 000000000..5065b3b2a --- /dev/null +++ b/mapping/synonym.go @@ -0,0 +1,56 @@ +// Copyright (c) 2024 Couchbase, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mapping + +import ( + "fmt" + + "github.com/blevesearch/bleve/v2/registry" +) + +type SynonymSource struct { + CollectionName string `json:"collection"` + AnalyzerName string `json:"analyzer"` +} + +func (s *SynonymSource) Collection() string { + return s.CollectionName +} + +func (s *SynonymSource) Analyzer() string { + return s.AnalyzerName +} + +func (s *SynonymSource) SetCollection(c string) { + s.CollectionName = c +} + +func (s *SynonymSource) SetAnalyzer(a string) { + s.AnalyzerName = a +} + +func (s *SynonymSource) Validate(c *registry.Cache) error { + if s.CollectionName == "" { + return fmt.Errorf("collection name is required") + } + if s.AnalyzerName == "" { + return fmt.Errorf("analyzer name is required") + } + _, err := c.AnalyzerNamed(s.AnalyzerName) + if err != nil { + return fmt.Errorf("analyzer named '%s' not found", s.AnalyzerName) + } + return nil +} diff --git a/search/searcher/search_term.go b/search/searcher/search_term.go index cd794ea32..a2244e8e0 100644 --- a/search/searcher/search_term.go +++ b/search/searcher/search_term.go @@ -38,15 +38,23 @@ type TermSearcher struct { tfd index.TermFieldDoc } -func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { +func NewTermSearcher(ctx context.Context, indexReader index.IndexReader, term string, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { if isTermQuery(ctx) { ctx = context.WithValue(ctx, search.QueryTypeKey, search.Term) } return NewTermSearcherBytes(ctx, indexReader, []byte(term), field, boost, options) } -func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { +func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, term []byte, field string, boost float64, options search.SearcherOptions) (search.Searcher, error) { needFreqNorm := options.Score != "none" + if fieldTermSynonyms, ok := ctx.Value(search.FieldTermSynonymsKey).(search.FieldTermSynonyms); ok { + if termSynonyms, ok := fieldTermSynonyms[field]; ok { + synonyms := termSynonyms[string(term)] + if len(synonyms) > 0 { + return newSynonymSearcherFromReader(ctx, indexReader, term, synonyms, field, boost, options, needFreqNorm) + } + } + } reader, err := indexReader.TermFieldReader(ctx, term, field, needFreqNorm, needFreqNorm, options.IncludeTermVectors) if err != nil { return nil, err @@ -54,6 +62,46 @@ func NewTermSearcherBytes(ctx context.Context, indexReader index.IndexReader, te return newTermSearcherFromReader(indexReader, reader, term, field, boost, options) } +func newSynonymSearcherFromReader(ctx context.Context, indexReader index.IndexReader, term []byte, synonyms []string, + field string, boost float64, options search.SearcherOptions, needFreqNorm bool) (search.Searcher, error) { + qsearchers := make([]search.Searcher, 0, len(synonyms)+1) + qsearchersClose := func() { + for _, searcher := range qsearchers { + if searcher != nil { + _ = searcher.Close() + } + } + } + for _, synonym := range synonyms { + synonymReader, err := indexReader.TermFieldReader(ctx, []byte(synonym), field, needFreqNorm, needFreqNorm, options.IncludeTermVectors) + if err != nil { + return nil, err + } + searcher, err := newTermSearcherFromReader(indexReader, synonymReader, []byte(synonym), field, boost, options) + if err != nil { + qsearchersClose() + return nil, err + } + qsearchers = append(qsearchers, searcher) + } + reader, err := indexReader.TermFieldReader(ctx, term, field, needFreqNorm, needFreqNorm, options.IncludeTermVectors) + if err != nil { + return nil, err + } + searcher, err := newTermSearcherFromReader(indexReader, reader, term, field, boost, options) + if err != nil { + qsearchersClose() + return nil, err + } + qsearchers = append(qsearchers, searcher) + rv, err := newDisjunctionSearcher(ctx, indexReader, qsearchers, 1, options, true) + if err != nil { + qsearchersClose() + return nil, err + } + return rv, nil +} + func newTermSearcherFromReader(indexReader index.IndexReader, reader index.TermFieldReader, term []byte, field string, boost float64, options search.SearcherOptions) (*TermSearcher, error) { count, err := indexReader.DocCount() diff --git a/search/util.go b/search/util.go index 6472803d1..22e1dceb6 100644 --- a/search/util.go +++ b/search/util.go @@ -17,6 +17,10 @@ package search import ( "context" + "github.com/blevesearch/bleve/v2/mapping" + "github.com/blevesearch/bleve/v2/search/query" + "github.com/blevesearch/bleve/v2/util" + index "github.com/blevesearch/bleve_index_api" "github.com/blevesearch/geo/s2" ) @@ -137,6 +141,12 @@ type GeoBufferPoolCallbackFunc func() *s2.GeoBufferPool const KnnPreSearchDataKey = "_knn_pre_search_data_key" +const SynonymPreSearchDataKey = "_synonym_pre_search_data_key" + +const FieldTermSynonymsKey = "_field_term_synonyms_key" + +type FieldTermSynonyms map[string]map[string][]string + const PreSearchKey = "_presearch_key" type ScoreExplCorrectionCallbackFunc func(queryMatch *DocumentMatch, knnMatch *DocumentMatch) (float64, *Explanation) @@ -146,3 +156,146 @@ type SearcherEndCallbackFn func(size uint64) error const SearcherStartCallbackKey = "_searcher_start_callback_key" const SearcherEndCallbackKey = "_searcher_end_callback_key" + +func getFieldTermsFromQuery(q query.Query, m mapping.IndexMapping, r index.IndexReader, rv map[string][]string) error { + if q == nil { + return nil + } + switch q := q.(type) { + case *query.TermQuery: + field := q.Field() + if field == "" { + field = m.DefaultSearchField() + } + rv[field] = []string{string(q.Term)} + + case *query.FuzzyQuery: + field := q.Field() + if field == "" { + field = m.DefaultSearchField() + } + rv[field] = []string{q.Term} + + default: + return nil, nil + } + + _, hasFuzziness := tmp["fuzziness"] + _, isMatchQuery := tmp["match"] + _, isMatchPhraseQuery := tmp["match_phrase"] + _, hasTerms := tmp["terms"] + if hasFuzziness && !isMatchQuery && !isMatchPhraseQuery && !hasTerms { + var rv FuzzyQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + if isMatchQuery { + var rv MatchQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + if isMatchPhraseQuery { + var rv MatchPhraseQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + if hasTerms { + var rv PhraseQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + // now try multi-phrase + var rv2 MultiPhraseQuery + err = util.UnmarshalJSON(input, &rv2) + if err != nil { + return nil, err + } + return &rv2, nil + } + return &rv, nil + } + _, hasMust := tmp["must"] + _, hasShould := tmp["should"] + _, hasMustNot := tmp["must_not"] + if hasMust || hasShould || hasMustNot { + var rv BooleanQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + _, hasConjuncts := tmp["conjuncts"] + if hasConjuncts { + var rv ConjunctionQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + _, hasDisjuncts := tmp["disjuncts"] + if hasDisjuncts { + var rv DisjunctionQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + + _, hasSyntaxQuery := tmp["query"] + if hasSyntaxQuery { + var rv QueryStringQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + _, hasMinStr := tmp["min"].(string) + _, hasMaxStr := tmp["max"].(string) + if hasMinStr || hasMaxStr { + var rv TermRangeQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + _, hasPrefix := tmp["prefix"] + if hasPrefix { + var rv PrefixQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + _, hasRegexp := tmp["regexp"] + if hasRegexp { + var rv RegexpQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } + _, hasWildcard := tmp["wildcard"] + if hasWildcard { + var rv WildcardQuery + err := util.UnmarshalJSON(input, &rv) + if err != nil { + return nil, err + } + return &rv, nil + } +}