From b3dfeccb5f480c43210f9c0b56a75bdb375bf536 Mon Sep 17 00:00:00 2001 From: Andrew Valencik Date: Sat, 22 Feb 2025 19:39:14 -0500 Subject: [PATCH] Add topN arg to Scorer, use for SearchRequest.size --- .../pink/cozydev/protosearch/Scorer.scala | 31 +++++++++++++++---- .../protosearch/SearchInterpreter.scala | 2 +- .../cozydev/protosearch/ScorerSuite.scala | 14 ++++++++- .../pink/cozydev/protosearch/JsInterop.scala | 2 +- 4 files changed, 40 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/pink/cozydev/protosearch/Scorer.scala b/core/src/main/scala/pink/cozydev/protosearch/Scorer.scala index 88735611..ed81267f 100644 --- a/core/src/main/scala/pink/cozydev/protosearch/Scorer.scala +++ b/core/src/main/scala/pink/cozydev/protosearch/Scorer.scala @@ -20,7 +20,6 @@ import cats.data.NonEmptyList import cats.syntax.all.* import pink.cozydev.lucille.Query -import scala.collection.mutable.HashMap import pink.cozydev.protosearch.internal.PositionalIter import java.util.regex.PatternSyntaxException @@ -29,7 +28,7 @@ case class Scorer(index: MultiIndex, defaultOR: Boolean = true) { private val defaultIdx: Index = index.indexes(index.schema.defaultField) - def score(qs: Query, docs: Set[Int]): Either[String, List[(Int, Double)]] = { + def score(qs: Query, docs: Set[Int], topN: Int): Either[String, List[(Int, Double)]] = { def accScore( idx: Index, queries: NonEmptyList[Query], @@ -57,7 +56,7 @@ case class Scorer(index: MultiIndex, defaultOR: Boolean = true) { case q: Query.Boost => Left(s"Unsupported Boost in Scorer: $q") case q: Query.WildCard => Left(s"Unsupported WildCard in Scorer: $q") } - accScore(defaultIdx, NonEmptyList.one(qs)).map(combineMaps) + accScore(defaultIdx, NonEmptyList.one(qs)).map(ms => combineMaps(ms, topN)) } private def phraseScore( @@ -124,11 +123,31 @@ case class Scorer(index: MultiIndex, defaultOR: Boolean = true) { } } - private def combineMaps(ms: NonEmptyList[Map[Int, Double]]): List[(Int, Double)] = { - val mb = HashMap.empty ++ ms.head + private val ord = Ordering[(Double, Int)].on[(Int, Double)](idScore => (-idScore._2, idScore._1)) + + private def combineMaps(ms: NonEmptyList[Map[Int, Double]], topN: Int): List[(Int, Double)] = { + // Combine scores by Id + val mb = scala.collection.mutable.HashMap.empty[Int, Double] ++ ms.head ms.tail.foreach(m1 => m1.foreach { case (k: Int, v: Double) => mb.update(k, v + mb.getOrElse(k, 0.0)) } ) - mb.toList.sortBy(idScore => (-idScore._2, idScore._1)) + // Sort scores + val arr = new Array[Tuple2[Int, Double]](mb.size) + var i = 0 + mb.foreach { docScore => + arr(i) = docScore + i += 1 + } + java.util.Arrays.sort(arr, ord) + // Build list of topN + val bldr = List.newBuilder[(Int, Double)] + val resultSize = if (mb.size < topN) mb.size else topN + bldr.sizeHint(resultSize) + var n = 0 + while (n < resultSize) { + bldr += arr(n) + n += 1 + } + bldr.result() } } diff --git a/core/src/main/scala/pink/cozydev/protosearch/SearchInterpreter.scala b/core/src/main/scala/pink/cozydev/protosearch/SearchInterpreter.scala index 02949426..7644d47b 100644 --- a/core/src/main/scala/pink/cozydev/protosearch/SearchInterpreter.scala +++ b/core/src/main/scala/pink/cozydev/protosearch/SearchInterpreter.scala @@ -36,7 +36,7 @@ final case class SearchInterpreter( parseQ.flatMap(q => indexSearcher .search(q) - .flatMap(ds => scorer.score(q, ds)) + .flatMap(ds => scorer.score(q, ds, request.size)) ) val lstB = List.newBuilder[Hit] diff --git a/core/src/test/scala/pink/cozydev/protosearch/ScorerSuite.scala b/core/src/test/scala/pink/cozydev/protosearch/ScorerSuite.scala index eded8b7b..98f1d6e1 100644 --- a/core/src/test/scala/pink/cozydev/protosearch/ScorerSuite.scala +++ b/core/src/test/scala/pink/cozydev/protosearch/ScorerSuite.scala @@ -38,7 +38,7 @@ class ScorerSuite extends munit.FunSuite { def score(q: String, docs: Set[Int]): Either[String, List[(Int, Double)]] = QueryParser .parse(q) - .flatMap(q => scorer.score(q, docs)) + .flatMap(q => scorer.score(q, docs, 10)) def ordered(hits: Either[String, List[(Int, Double)]]): List[Int] = hits.fold(_ => Nil, ds => ds.map(_._1)) @@ -74,4 +74,16 @@ class ScorerSuite extends munit.FunSuite { assertEquals(ordered(hits), List(1)) } + test("scorer with topN=1 returns only top doc") { + val q = "Tale OR Two" // has 3 matches + val hits = QueryParser.parse(q).flatMap(q => scorer.score(q, allDocs, topN = 1)) + assertEquals(ordered(hits), List(1)) + } + + test("scorer topN can be bigger than number of matches") { + val q = "Tale OR Two" // has 3 matches + val hits = QueryParser.parse(q).flatMap(q => scorer.score(q, allDocs, topN = 999)) + assertEquals(ordered(hits), List(1, 0, 2)) + } + } diff --git a/jsinterop/src/main/scala/pink/cozydev/protosearch/JsInterop.scala b/jsinterop/src/main/scala/pink/cozydev/protosearch/JsInterop.scala index 6963e651..cc9bc91f 100644 --- a/jsinterop/src/main/scala/pink/cozydev/protosearch/JsInterop.scala +++ b/jsinterop/src/main/scala/pink/cozydev/protosearch/JsInterop.scala @@ -41,7 +41,7 @@ class Querier(val mIndex: MultiIndex) { @JSExport def search(query: String): js.Array[JsHit] = { - val req = SearchRequest(query, 10, highlightFields, resultFields, true) + val req = SearchRequest(query, size = 10, highlightFields, resultFields, lastTermPrefix = true) val hits = searcher .search(req) .fold(