caio.co/de/cantine

Add support for customizing the scores

Mostly a copy-paste of `tique::top_collector::custom_score`, with cleaned-up
trait requirements and added support for type-based ordering.
Id
7f1c7f3c20e604fba6a35dc45aec5e3324086ba6
Author
Caio
Commit time
2020-01-21T16:02:12+01:00

Modified tique/src/conditional_collector/mod.rs

@@ -1,7 +1,9
+mod custom_score;
mod top_collector;
mod topk;
-mod traits;

-pub use top_collector::{CollectionResult, TopCollector, TopSegmentCollector};
+pub mod traits;
+
+pub use custom_score::CustomScoreTopCollector;
+pub use top_collector::{CollectionResult, TopCollector};
pub use topk::{Ascending, Descending};
-pub use traits::{CheckCondition, ConditionForSegment};

Modified tique/src/conditional_collector/top_collector.rs

@@ -7,12 +7,12

use super::{
topk::{TopK, TopKProvider},
- CheckCondition, ConditionForSegment,
+ traits::{CheckCondition, ConditionForSegment},
};

pub struct TopCollector<T, P, CF> {
limit: usize,
- condition_factory: CF,
+ condition_for_segment: CF,
_score: PhantomData<T>,
_provider: PhantomData<P>,
}
@@ -20,16 +20,16
impl<T, P, CF> TopCollector<T, P, CF>
where
T: PartialOrd,
- P: 'static + Send + Sync + TopKProvider<Score>,
- CF: ConditionForSegment<T> + Sync,
+ P: TopKProvider<T>,
+ CF: ConditionForSegment<T>,
{
- pub fn new(limit: usize, condition_factory: CF) -> Self {
+ pub fn new(limit: usize, condition_for_segment: CF) -> Self {
if limit < 1 {
panic!("Limit must be greater than 0");
}
TopCollector {
limit,
- condition_factory,
+ condition_for_segment,
_score: PhantomData,
_provider: PhantomData,
}
@@ -39,7 +39,7
impl<P, CF> Collector for TopCollector<Score, P, CF>
where
P: 'static + Send + Sync + TopKProvider<Score>,
- CF: ConditionForSegment<Score> + Sync,
+ CF: Sync + ConditionForSegment<Score>,
{
type Fruit = CollectionResult<Score>;
type Child = TopSegmentCollector<Score, P::Child, CF::Type>;
@@ -60,7 +60,7
Ok(TopSegmentCollector::new(
segment_id,
P::new_topk(self.limit),
- self.condition_factory.for_segment(reader),
+ self.condition_for_segment.for_segment(reader),
))
}
}
@@ -76,10 +76,11

impl<T, K, C> TopSegmentCollector<T, K, C>
where
+ T: Copy,
K: TopK<T, DocId>,
C: CheckCondition<T>,
{
- fn new(segment_id: SegmentLocalId, topk: K, condition: C) -> Self {
+ pub fn new(segment_id: SegmentLocalId, topk: K, condition: C) -> Self {
Self {
total: 0,
visited: 0,
@@ -94,16 +95,8
fn into_topk(self) -> K {
self.topk
}
-}

-impl<K, C> SegmentCollector for TopSegmentCollector<Score, K, C>
-where
- K: TopK<Score, DocId> + 'static,
- C: CheckCondition<Score>,
-{
- type Fruit = CollectionResult<Score>;
-
- fn collect(&mut self, doc: DocId, score: Score) {
+ pub fn collect(&mut self, doc: DocId, score: T) {
self.total += 1;
if self
.condition
@@ -114,7 +107,7
}
}

- fn harvest(self) -> Self::Fruit {
+ pub fn into_collection_result(self) -> CollectionResult<T> {
let segment_id = self.segment_id;
let items = self
.topk
@@ -131,6 +124,22
visited: self.visited,
items,
}
+ }
+}
+
+impl<K, C> SegmentCollector for TopSegmentCollector<Score, K, C>
+where
+ K: 'static + TopK<Score, DocId>,
+ C: CheckCondition<Score>,
+{
+ type Fruit = CollectionResult<Score>;
+
+ fn collect(&mut self, doc: DocId, score: Score) {
+ TopSegmentCollector::collect(self, doc, score)
+ }
+
+ fn harvest(self) -> Self::Fruit {
+ TopSegmentCollector::into_collection_result(self)
}
}

Modified tique/src/conditional_collector/traits.rs

@@ -61,3 +61,33
== wanted
}
}
+
+pub trait ScorerForSegment<T>: Sync {
+ type Type: DocScorer<T>;
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
+}
+
+impl<T, C, F> ScorerForSegment<T> for F
+where
+ F: 'static + Sync + Send + Fn(&SegmentReader) -> C,
+ C: DocScorer<T>,
+{
+ type Type = C;
+
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
+ (self)(reader)
+ }
+}
+
+pub trait DocScorer<T>: 'static {
+ fn score(&self, doc_id: DocId) -> T;
+}
+
+impl<F, T> DocScorer<T> for F
+where
+ F: 'static + Sync + Send + Fn(DocId) -> T,
+{
+ fn score(&self, doc_id: DocId) -> T {
+ (self)(doc_id)
+ }
+}

Created tique/src/conditional_collector/custom_score.rs

@@ -1,0 +1,180
+use std::marker::PhantomData;
+
+use tantivy::{
+ collector::{Collector, SegmentCollector},
+ DocId, Result, Score, SegmentLocalId, SegmentReader,
+};
+
+use super::{
+ top_collector::TopSegmentCollector,
+ topk::{TopK, TopKProvider},
+ traits::{CheckCondition, ConditionForSegment, DocScorer, ScorerForSegment},
+ CollectionResult,
+};
+
+pub struct CustomScoreTopCollector<T, P, C, S>
+where
+ T: PartialOrd,
+ P: TopKProvider<T>,
+ C: ConditionForSegment<T>,
+{
+ limit: usize,
+ scorer_for_segment: S,
+ condition_for_segment: C,
+ _score: PhantomData<T>,
+ _provider: PhantomData<P>,
+}
+
+impl<T, P, C, S> CustomScoreTopCollector<T, P, C, S>
+where
+ T: PartialOrd,
+ P: TopKProvider<T>,
+ C: ConditionForSegment<T>,
+{
+ pub fn new(limit: usize, condition_for_segment: C, scorer_for_segment: S) -> Self {
+ Self {
+ limit,
+ scorer_for_segment,
+ condition_for_segment,
+ _score: PhantomData,
+ _provider: PhantomData,
+ }
+ }
+}
+
+impl<T, P, C, S> Collector for CustomScoreTopCollector<T, P, C, S>
+where
+ T: 'static + PartialOrd + Copy + Send + Sync,
+ P: 'static + Send + Sync + TopKProvider<T>,
+ C: Sync + ConditionForSegment<T>,
+ S: ScorerForSegment<T>,
+{
+ type Fruit = CollectionResult<T>;
+ type Child = CustomScoreTopSegmentCollector<T, C::Type, S::Type, P::Child>;
+
+ fn requires_scoring(&self) -> bool {
+ false
+ }
+
+ fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
+ Ok(P::merge_many(self.limit, children))
+ }
+
+ fn for_segment(
+ &self,
+ segment_id: SegmentLocalId,
+ reader: &SegmentReader,
+ ) -> Result<Self::Child> {
+ let scorer = self.scorer_for_segment.for_segment(reader);
+ Ok(CustomScoreTopSegmentCollector::new(
+ segment_id,
+ P::new_topk(self.limit),
+ scorer,
+ self.condition_for_segment.for_segment(reader),
+ ))
+ }
+}
+
+pub struct CustomScoreTopSegmentCollector<T, C, S, K>
+where
+ C: CheckCondition<T>,
+ K: TopK<T, DocId>,
+{
+ scorer: S,
+ collector: TopSegmentCollector<T, K, C>,
+}
+
+impl<T, C, S, K> CustomScoreTopSegmentCollector<T, C, S, K>
+where
+ T: Copy,
+ C: CheckCondition<T>,
+ K: TopK<T, DocId>,
+{
+ pub fn new(segment_id: SegmentLocalId, topk: K, scorer: S, condition: C) -> Self {
+ Self {
+ scorer,
+ collector: TopSegmentCollector::new(segment_id, topk, condition),
+ }
+ }
+}
+
+impl<T, C, S, K> SegmentCollector for CustomScoreTopSegmentCollector<T, C, S, K>
+where
+ T: 'static + PartialOrd + Copy + Send + Sync,
+ K: 'static + TopK<T, DocId>,
+ C: CheckCondition<T>,
+ S: DocScorer<T>,
+{
+ type Fruit = CollectionResult<T>;
+
+ fn collect(&mut self, doc: DocId, _: Score) {
+ let score = self.scorer.score(doc);
+ self.collector.collect(doc, score);
+ }
+
+ fn harvest(self) -> Self::Fruit {
+ self.collector.into_collection_result()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::conditional_collector::{topk::AscendingTopK, Descending};
+
+ use tantivy::{query::AllQuery, schema::SchemaBuilder, Document, Index};
+
+ #[test]
+ fn custom_segment_scorer_gets_called() {
+ let mut collector = CustomScoreTopSegmentCollector::new(
+ 0,
+ AscendingTopK::new(1),
+ // Use the doc_id as the score
+ |doc_id| doc_id,
+ true,
+ );
+
+ // So that whatever we provide as a score
+ collector.collect(1, 42.0);
+ let res = collector.harvest();
+ assert_eq!(1, res.total);
+
+ let got = &res.items[0];
+ // Is disregarded and doc_id is used instead
+ assert_eq!((got.1).1, got.0)
+ }
+
+ #[test]
+ fn custom_top_scorer_integration() -> Result<()> {
+ let builder = SchemaBuilder::new();
+ let index = Index::create_in_ram(builder.build());
+
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
+
+ // We add 100 documents to our index
+ for _ in 0..100 {
+ writer.add_document(Document::new());
+ }
+
+ writer.commit()?;
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ let colletor =
+ CustomScoreTopCollector::<_, Descending, _, _>::new(2, true, |_: &SegmentReader| {
+ |doc_id: DocId| u64::from(doc_id * 10)
+ });
+
+ let result = searcher.search(&AllQuery, &colletor)?;
+
+ assert_eq!(100, result.total);
+ assert_eq!(2, result.items.len());
+
+ // So we expect that the highest score is 990
+ assert_eq!(result.items[0].0, 990);
+ assert_eq!(result.items[1].0, 980);
+
+ Ok(())
+ }
+}