Skip to content

Commit

Permalink
Add topN arg to Scorer, use for SearchRequest.size
Browse files Browse the repository at this point in the history
  • Loading branch information
valencik committed Feb 23, 2025
1 parent c09e445 commit b3dfecc
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
31 changes: 25 additions & 6 deletions core/src/main/scala/pink/cozydev/protosearch/Scorer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import cats.data.NonEmptyList
import cats.syntax.all.*
import pink.cozydev.lucille.Query

import scala.collection.mutable.HashMap
import pink.cozydev.protosearch.internal.PositionalIter

import java.util.regex.PatternSyntaxException
Expand All @@ -29,7 +28,7 @@ case class Scorer(index: MultiIndex, defaultOR: Boolean = true) {

private val defaultIdx: Index = index.indexes(index.schema.defaultField)

def score(qs: Query, docs: Set[Int]): Either[String, List[(Int, Double)]] = {
def score(qs: Query, docs: Set[Int], topN: Int): Either[String, List[(Int, Double)]] = {
def accScore(
idx: Index,
queries: NonEmptyList[Query],
Expand Down Expand Up @@ -57,7 +56,7 @@ case class Scorer(index: MultiIndex, defaultOR: Boolean = true) {
case q: Query.Boost => Left(s"Unsupported Boost in Scorer: $q")
case q: Query.WildCard => Left(s"Unsupported WildCard in Scorer: $q")
}
accScore(defaultIdx, NonEmptyList.one(qs)).map(combineMaps)
accScore(defaultIdx, NonEmptyList.one(qs)).map(ms => combineMaps(ms, topN))
}

private def phraseScore(
Expand Down Expand Up @@ -124,11 +123,31 @@ case class Scorer(index: MultiIndex, defaultOR: Boolean = true) {
}
}

private def combineMaps(ms: NonEmptyList[Map[Int, Double]]): List[(Int, Double)] = {
val mb = HashMap.empty ++ ms.head
private val ord = Ordering[(Double, Int)].on[(Int, Double)](idScore => (-idScore._2, idScore._1))

private def combineMaps(ms: NonEmptyList[Map[Int, Double]], topN: Int): List[(Int, Double)] = {
// Combine scores by Id
val mb = scala.collection.mutable.HashMap.empty[Int, Double] ++ ms.head
ms.tail.foreach(m1 =>
m1.foreach { case (k: Int, v: Double) => mb.update(k, v + mb.getOrElse(k, 0.0)) }
)
mb.toList.sortBy(idScore => (-idScore._2, idScore._1))
// Sort scores
val arr = new Array[Tuple2[Int, Double]](mb.size)
var i = 0
mb.foreach { docScore =>
arr(i) = docScore
i += 1
}
java.util.Arrays.sort(arr, ord)
// Build list of topN
val bldr = List.newBuilder[(Int, Double)]
val resultSize = if (mb.size < topN) mb.size else topN
bldr.sizeHint(resultSize)
var n = 0
while (n < resultSize) {
bldr += arr(n)
n += 1
}
bldr.result()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ final case class SearchInterpreter(
parseQ.flatMap(q =>
indexSearcher
.search(q)
.flatMap(ds => scorer.score(q, ds))
.flatMap(ds => scorer.score(q, ds, request.size))
)

val lstB = List.newBuilder[Hit]
Expand Down
14 changes: 13 additions & 1 deletion core/src/test/scala/pink/cozydev/protosearch/ScorerSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ScorerSuite extends munit.FunSuite {
def score(q: String, docs: Set[Int]): Either[String, List[(Int, Double)]] =
QueryParser
.parse(q)
.flatMap(q => scorer.score(q, docs))
.flatMap(q => scorer.score(q, docs, 10))

def ordered(hits: Either[String, List[(Int, Double)]]): List[Int] =
hits.fold(_ => Nil, ds => ds.map(_._1))
Expand Down Expand Up @@ -74,4 +74,16 @@ class ScorerSuite extends munit.FunSuite {
assertEquals(ordered(hits), List(1))
}

test("scorer with topN=1 returns only top doc") {
val q = "Tale OR Two" // has 3 matches
val hits = QueryParser.parse(q).flatMap(q => scorer.score(q, allDocs, topN = 1))
assertEquals(ordered(hits), List(1))
}

test("scorer topN can be bigger than number of matches") {
val q = "Tale OR Two" // has 3 matches
val hits = QueryParser.parse(q).flatMap(q => scorer.score(q, allDocs, topN = 999))
assertEquals(ordered(hits), List(1, 0, 2))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Querier(val mIndex: MultiIndex) {

@JSExport
def search(query: String): js.Array[JsHit] = {
val req = SearchRequest(query, 10, highlightFields, resultFields, true)
val req = SearchRequest(query, size = 10, highlightFields, resultFields, lastTermPrefix = true)
val hits = searcher
.search(req)
.fold(
Expand Down

0 comments on commit b3dfecc

Please sign in to comment.