Sunset top_collector module
The `conditional_collector` module effectively replaces the `top_collector` module. The replacement still doesn't contain a `TweakedScoreTopCollector`, but that's simply because I don't need one anymore: I'm no longer modifying scores to simulate a change in ordering.
- Id
- 1757c20b79e1e12726f7f8ff1238e7554f827e8f
- Author
- Caio
- Commit time
- 2020-01-21T17:41:15+01:00
Modified tique/src/lib.rs
pub mod conditional_collector;
pub mod queryparser;
-pub mod top_collector;
mod derive;
Deleted tique/src/top_collector/conditional_collector.rs
-use std::marker::PhantomData;
-
-use tantivy::{
- collector::{Collector, SegmentCollector},
- DocAddress, DocId, Result, Score, SegmentLocalId, SegmentReader,
-};
-
-use super::{Scored, TopK};
-
-pub trait ConditionForSegment<T>: Clone {
- type Type: CheckCondition<T>;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
-}
-
-impl<T, C, F> ConditionForSegment<T> for F
-where
- F: Clone + Fn(&SegmentReader) -> C,
- C: CheckCondition<T>,
-{
- type Type = C;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- (self)(reader)
- }
-}
-
-impl<T> ConditionForSegment<T> for bool {
- type Type = bool;
- fn for_segment(&self, _reader: &SegmentReader) -> Self::Type {
- *self
- }
-}
-
-pub trait CheckCondition<T>: 'static + Clone {
- fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T) -> bool;
-}
-
-impl<T> CheckCondition<T> for bool {
- fn check(&self, _: SegmentLocalId, _: DocId, _: T) -> bool {
- *self
- }
-}
-
-impl<F, T> CheckCondition<T> for F
-where
- F: 'static + Clone + Fn(SegmentLocalId, DocId, T) -> bool,
-{
- fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T) -> bool {
- (self)(segment_id, doc_id, score)
- }
-}
-
-pub type SearchMarker<T> = Scored<T, DocAddress>;
-
-impl<T> CheckCondition<T> for SearchMarker<T>
-where
- T: 'static + PartialOrd + Clone,
-{
- fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T) -> bool {
- // So: only collect items that would come _after_ this marker
- *self > Scored::new(score, DocAddress(segment_id, doc_id))
- }
-}
-
-pub struct ConditionalTopCollector<T, F>
-where
- F: ConditionForSegment<T>,
-{
- pub limit: usize,
- condition_factory: F,
- _marker: PhantomData<T>,
-}
-
-impl<T, F> ConditionalTopCollector<T, F>
-where
- T: PartialOrd,
- F: ConditionForSegment<T>,
-{
- pub fn with_limit(limit: usize, condition_factory: F) -> Self {
- if limit < 1 {
- panic!("Limit must be greater than 0");
- }
- ConditionalTopCollector {
- limit,
- condition_factory,
- _marker: PhantomData,
- }
- }
-
- pub fn merge_many(&self, children: Vec<CollectionResult<T>>) -> CollectionResult<T> {
- CollectionResult::merge_many(self.limit, children)
- }
-}
-
-impl<F> Collector for ConditionalTopCollector<Score, F>
-where
- F: ConditionForSegment<Score> + Sync,
-{
- type Fruit = CollectionResult<Score>;
- type Child = ConditionalTopSegmentCollector<Score, F::Type>;
-
- fn requires_scoring(&self) -> bool {
- true
- }
-
- fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
- Ok(self.merge_many(children))
- }
-
- fn for_segment(
- &self,
- segment_id: SegmentLocalId,
- reader: &SegmentReader,
- ) -> Result<Self::Child> {
- Ok(ConditionalTopSegmentCollector::new(
- segment_id,
- self.limit,
- self.condition_factory.for_segment(reader),
- ))
- }
-}
-
-pub struct ConditionalTopSegmentCollector<T, F>
-where
- F: CheckCondition<T>,
-{
- segment_id: SegmentLocalId,
- collected: TopK<T, DocId>,
- visited: usize,
- total: usize,
- condition: F,
-}
-
-impl<T, F> ConditionalTopSegmentCollector<T, F>
-where
- T: PartialOrd + Copy,
- F: CheckCondition<T>,
-{
- pub fn new(segment_id: SegmentLocalId, limit: usize, condition: F) -> Self {
- ConditionalTopSegmentCollector {
- collected: TopK::new(limit),
- segment_id,
- condition,
- visited: 0,
- total: 0,
- }
- }
-
- #[cfg(test)]
- fn len(&self) -> usize {
- self.collected.len()
- }
-
- #[inline(always)]
- pub fn visit(&mut self, doc: DocId, score: T) {
- self.total += 1;
- if self.condition.check(self.segment_id, doc, score) {
- self.visited += 1;
- self.collected.visit(score, doc);
- }
- }
-
- pub fn into_collection_result(self) -> CollectionResult<T> {
- let segment_id = self.segment_id;
- let items = self
- .collected
- .into_vec()
- .into_iter()
- .map(|Scored { score, doc }| (score, DocAddress(segment_id, doc)))
- .collect();
-
- CollectionResult {
- total: self.total,
- visited: self.visited,
- items,
- }
- }
-}
-
-impl<F> SegmentCollector for ConditionalTopSegmentCollector<Score, F>
-where
- F: CheckCondition<Score>,
-{
- type Fruit = CollectionResult<Score>;
-
- fn collect(&mut self, doc: DocId, score: Score) {
- self.visit(doc, score);
- }
-
- fn harvest(self) -> Self::Fruit {
- self.into_collection_result()
- }
-}
-
-#[derive(Debug)]
-pub struct CollectionResult<T> {
- pub total: usize,
- pub visited: usize,
- pub items: Vec<(T, DocAddress)>,
-}
-
-impl<T: PartialOrd> CollectionResult<T> {
- pub fn merge_many(limit: usize, items: Vec<CollectionResult<T>>) -> CollectionResult<T> {
- let mut topk = TopK::new(limit);
- let mut total = 0;
- let mut visited = 0;
-
- for item in items {
- total += item.total;
- visited += item.visited;
-
- for (score, doc) in item.items {
- topk.visit(score, doc);
- }
- }
-
- CollectionResult {
- total,
- visited,
- items: topk
- .into_sorted_vec()
- .into_iter()
- .map(|Scored { score, doc }| (score, doc))
- .collect(),
- }
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- #[test]
- fn condition_is_checked() {
- const LIMIT: usize = 4;
-
- let mut nil_collector = ConditionalTopSegmentCollector::new(0, LIMIT, false);
-
- let mut top_collector = ConditionalTopSegmentCollector::new(0, LIMIT, true);
-
- let condition = |_sid, doc, _score| doc % 2 == 1;
-
- let mut just_odds = ConditionalTopSegmentCollector::new(0, LIMIT, condition);
-
- for i in 0..4 {
- nil_collector.collect(i, 420.0);
- top_collector.collect(i, 420.0);
- just_odds.collect(i, 420.0);
- }
-
- assert_eq!(0, nil_collector.len());
- assert_eq!(4, top_collector.len());
- assert_eq!(2, just_odds.len());
-
- // Verify that the collected items respect the condition
- let result = just_odds.harvest();
- assert_eq!(4, result.total);
- assert_eq!(2, result.visited);
- for (score, doc) in result.items {
- let DocAddress(seg_id, doc_id) = doc;
- assert!(condition(seg_id, doc_id, score))
- }
- }
-
- #[test]
- fn collection_with_a_marker_smoke() {
- // Doc id=4 on segment=0 had score=0.5
- let marker = Scored::new(0.5, DocAddress(0, 4));
- let mut collector = ConditionalTopSegmentCollector::new(0, 3, marker);
-
- // Every doc with a higher score has appeared already
- collector.collect(7, 0.6);
- collector.collect(5, 0.7);
- assert_eq!(0, collector.len());
-
- // Docs with the same score, but lower id too
- collector.collect(3, 0.5);
- collector.collect(2, 0.5);
- assert_eq!(0, collector.len());
-
- // And, of course, the same doc should not be collected
- collector.collect(4, 0.5);
- assert_eq!(0, collector.len());
-
- // Lower scores are in
- collector.collect(1, 0.0);
- // Same score but higher doc, too
- collector.collect(6, 0.5);
-
- assert_eq!(2, collector.len());
- }
-
- #[test]
- fn fruits_are_merged_correctly() {
- let collector = ConditionalTopCollector::with_limit(5, true);
-
- let merged = collector
- .merge_fruits(vec![
- // S0
- CollectionResult {
- total: 1,
- visited: 1,
- items: vec![(0.5, DocAddress(0, 1))],
- },
- // S1 has a doc that scored the same as S0, so
- // it should only appear *after* the one in S0
- CollectionResult {
- total: 1,
- visited: 1,
- items: vec![(0.5, DocAddress(1, 1)), (0.6, DocAddress(1, 2))],
- },
- // S2 has two evenly scored docs, the one with
- // the lowest internal id should appear first
- CollectionResult {
- total: 1,
- visited: 1,
- items: vec![(0.2, DocAddress(2, 2)), (0.2, DocAddress(2, 1))],
- },
- ])
- .unwrap();
-
- assert_eq!(
- vec![
- (0.6, DocAddress(1, 2)),
- (0.5, DocAddress(0, 1)),
- (0.5, DocAddress(1, 1)),
- (0.2, DocAddress(2, 1)),
- (0.2, DocAddress(2, 2))
- ],
- merged.items
- );
- }
-
- use tantivy::{query::AllQuery, schema, Document, Index, Result};
-
- #[test]
- fn only_collect_even_public_ids() -> Result<()> {
- let mut builder = schema::SchemaBuilder::new();
-
- let id_field = builder.add_u64_field("public_id", schema::FAST);
-
- let index = Index::create_in_ram(builder.build());
-
- let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
-
- const NUM_DOCS: u64 = 10;
- for public_id in 0..NUM_DOCS {
- let mut doc = Document::new();
- doc.add_u64(id_field, public_id);
- writer.add_document(doc);
- }
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
-
- let condition_factory = |reader: &SegmentReader| {
- let id_reader = reader.fast_fields().u64(id_field).unwrap();
-
- move |_segment_id, doc_id, _score| {
- let stored_id = id_reader.get(doc_id);
- stored_id % 2 == 0
- }
- };
- let results = searcher.search(
- &AllQuery,
- &ConditionalTopCollector::with_limit(NUM_DOCS as usize, condition_factory),
- )?;
-
- assert_eq!(5, results.items.len());
-
- Ok(())
- }
-}
Deleted tique/src/top_collector/custom_score.rs
-use tantivy::{
- collector::{Collector, SegmentCollector},
- DocId, Result, Score, SegmentLocalId, SegmentReader,
-};
-
-use super::{
- CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
- ConditionalTopSegmentCollector,
-};
-
-pub struct CustomScoreTopCollector<T, C, F>
-where
- C: ConditionForSegment<T>,
-{
- scorer_factory: F,
- condition_factory: C,
- collector: ConditionalTopCollector<T, C>,
-}
-
-impl<T, C, F> CustomScoreTopCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: ConditionForSegment<T>,
- F: 'static + Sync + ScorerForSegment<T>,
-{
- pub fn new(limit: usize, condition_factory: C, scorer_factory: F) -> Self {
- Self {
- collector: ConditionalTopCollector::with_limit(limit, condition_factory.clone()),
- scorer_factory,
- condition_factory,
- }
- }
-}
-
-pub trait ScorerForSegment<T>: Sync {
- type Type: DocScorer<T>;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
-}
-
-impl<T, C, F> ScorerForSegment<T> for F
-where
- F: 'static + Sync + Send + Fn(&SegmentReader) -> C,
- C: DocScorer<T>,
-{
- type Type = C;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- (self)(reader)
- }
-}
-
-impl<T, C, F> Collector for CustomScoreTopCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: ConditionForSegment<T> + Sync,
- F: 'static + ScorerForSegment<T>,
-{
- type Fruit = CollectionResult<T>;
- type Child = CustomScoreTopSegmentCollector<T, C::Type, F::Type>;
-
- fn requires_scoring(&self) -> bool {
- false
- }
-
- fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
- Ok(self.collector.merge_many(children))
- }
-
- fn for_segment(
- &self,
- segment_id: SegmentLocalId,
- reader: &SegmentReader,
- ) -> Result<Self::Child> {
- let scorer = self.scorer_factory.for_segment(reader);
- Ok(CustomScoreTopSegmentCollector::new(
- segment_id,
- self.collector.limit,
- self.condition_factory.for_segment(reader),
- scorer,
- ))
- }
-}
-
-pub struct CustomScoreTopSegmentCollector<T, C, F>
-where
- C: CheckCondition<T>,
-{
- scorer: F,
- collector: ConditionalTopSegmentCollector<T, C>,
-}
-
-impl<T, C, F> CustomScoreTopSegmentCollector<T, C, F>
-where
- T: PartialOrd + Copy,
- C: CheckCondition<T>,
- F: DocScorer<T>,
-{
- fn new(segment_id: SegmentLocalId, limit: usize, condition: C, scorer: F) -> Self {
- Self {
- scorer,
- collector: ConditionalTopSegmentCollector::new(segment_id, limit, condition),
- }
- }
-}
-
-impl<T, C, F> SegmentCollector for CustomScoreTopSegmentCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: CheckCondition<T>,
- F: DocScorer<T>,
-{
- type Fruit = CollectionResult<T>;
-
- fn collect(&mut self, doc: DocId, _: Score) {
- let score = self.scorer.score(doc);
- self.collector.visit(doc, score);
- }
-
- fn harvest(self) -> Self::Fruit {
- self.collector.into_collection_result()
- }
-}
-
-pub trait DocScorer<T>: 'static {
- fn score(&self, doc_id: DocId) -> T;
-}
-
-impl<F, T> DocScorer<T> for F
-where
- F: 'static + Sync + Send + Fn(DocId) -> T,
-{
- fn score(&self, doc_id: DocId) -> T {
- (self)(doc_id)
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use tantivy::{query::AllQuery, schema::SchemaBuilder, Document, Index};
-
- #[test]
- fn custom_segment_scorer_gets_called() {
- // Use the doc_id as the score
- let mut collector = CustomScoreTopSegmentCollector::new(0, 1, true, |doc_id| doc_id);
-
- // So that whatever we provide as a score
- collector.collect(1, 42.0);
- let res = collector.harvest();
- assert_eq!(1, res.total);
-
- let got = &res.items[0];
- // Is disregarded and doc_id is used instead
- assert_eq!((got.1).1, got.0)
- }
-
- #[test]
- fn custom_top_scorer_integration() -> Result<()> {
- let builder = SchemaBuilder::new();
- let index = Index::create_in_ram(builder.build());
-
- let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
-
- // We add 100 documents to our index
- for _ in 0..100 {
- writer.add_document(Document::new());
- }
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
-
- let colletor = CustomScoreTopCollector::new(2, true, |_: &SegmentReader| {
- // Score is doc_id * 10
- |doc_id: DocId| doc_id * 10
- });
-
- let result = searcher.search(&AllQuery, &colletor)?;
-
- assert_eq!(100, result.total);
- assert_eq!(2, result.items.len());
-
- // So we expect that the highest score is 990
- assert_eq!(result.items[0].0, 990);
- assert_eq!(result.items[1].0, 980);
-
- Ok(())
- }
-}
Deleted tique/src/top_collector/fastfield.rs
-use std::{marker::PhantomData, ops::Neg};
-
-use tantivy::{
- fastfield::{FastFieldReader, FastValue},
- schema::Field,
- DocId, SegmentReader,
-};
-
-use super::{DocScorer, ScorerForSegment};
-
-pub fn descending<T>(field: Field) -> DescendingFastField<T> {
- DescendingFastField(field, PhantomData)
-}
-
-pub fn ascending<T>(field: Field) -> AscendingFastField<T> {
- AscendingFastField(field, PhantomData)
-}
-
-pub struct DescendingFastField<T>(Field, PhantomData<T>);
-
-pub struct AscendingFastField<T>(Field, PhantomData<T>);
-
-macro_rules! impl_scorer_for_segment {
- ($type: ident) => {
- impl ScorerForSegment<$type> for DescendingFastField<$type> {
- type Type = DescendingScorer<$type>;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- let scorer = reader.fast_fields().$type(self.0).expect("Field is FAST");
- DescendingScorer(scorer)
- }
- }
-
- impl ScorerForSegment<$type> for AscendingFastField<$type> {
- type Type = AscendingScorer<$type>;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- let scorer = reader.fast_fields().$type(self.0).expect("Field is FAST");
- AscendingScorer(scorer)
- }
- }
- };
-}
-
-impl_scorer_for_segment!(f64);
-impl_scorer_for_segment!(i64);
-impl_scorer_for_segment!(u64);
-
-pub struct DescendingScorer<T: FastValue>(FastFieldReader<T>);
-
-pub struct AscendingScorer<T: FastValue>(FastFieldReader<T>);
-
-impl<T> DocScorer<T> for DescendingScorer<T>
-where
- T: FastValue + 'static,
-{
- fn score(&self, doc_id: DocId) -> T {
- self.0.get(doc_id)
- }
-}
-
-impl DocScorer<u64> for AscendingScorer<u64> {
- fn score(&self, doc_id: DocId) -> u64 {
- std::u64::MAX - self.0.get(doc_id)
- }
-}
-
-macro_rules! impl_neg_reversed_scorer {
- ($type: ty) => {
- impl DocScorer<$type> for AscendingScorer<$type> {
- fn score(&self, doc_id: DocId) -> $type {
- self.0.get(doc_id).neg()
- }
- }
- };
-}
-
-impl_neg_reversed_scorer!(i64);
-impl_neg_reversed_scorer!(f64);
-
-#[cfg(test)]
-use super::{CollectionResult, CustomScoreTopCollector};
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use tantivy::{
- query::AllQuery,
- schema::{SchemaBuilder, FAST},
- DocAddress, Document, Index, Result,
- };
-
- fn just_the_ids<T: PartialOrd>(res: CollectionResult<T>) -> Vec<DocId> {
- res.items
- .into_iter()
- .map(|(_score, doc)| {
- let DocAddress(_segment, id) = doc;
- id
- })
- .collect()
- }
-
- macro_rules! check_order_from_sorted_values {
- ($name: ident, $field: ident, $add: ident, $type: ty, $values: expr) => {
- #[test]
- fn $name() -> Result<()> {
- let mut sb = SchemaBuilder::new();
-
- let field = sb.$field("field", FAST);
- let index = Index::create_in_ram(sb.build());
- let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
-
- for v in $values.iter() {
- let mut doc = Document::new();
- doc.$add(field, *v);
- writer.add_document(doc);
- }
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
- let size = $values.len();
- let condition = true;
-
- let collector =
- CustomScoreTopCollector::new(size, condition, descending::<$type>(field));
-
- let reversed_collector =
- CustomScoreTopCollector::new(size, condition, ascending::<$type>(field));
-
- let (top, reversed_top) =
- searcher.search(&AllQuery, &(collector, reversed_collector))?;
-
- let sorted_scores: Vec<$type> =
- top.items.iter().map(|(score, _doc)| *score).collect();
- assert_eq!(
- $values,
- sorted_scores.as_slice(),
- "found scores don't match input"
- );
-
- let ids = just_the_ids(top);
- let mut reversed_ids = just_the_ids(reversed_top);
-
- reversed_ids.reverse();
- assert_eq!(
- ids,
- reversed_ids.as_slice(),
- "should have found the same ids, in reversed order"
- );
-
- Ok(())
- }
- };
- }
-
- check_order_from_sorted_values!(
- u64_field_sort_functionality,
- add_u64_field,
- add_u64,
- u64,
- [3, 2, 1, 0]
- );
-
- check_order_from_sorted_values!(
- i64_field_sort_functionality,
- add_i64_field,
- add_i64,
- i64,
- [3, 2, 1, 0, -1, -2, -3]
- );
-
- check_order_from_sorted_values!(
- f64_field_sort_functionality,
- add_f64_field,
- add_f64,
- f64,
- [100.0, 42.0, 0.71, 0.42]
- );
-}
Deleted tique/src/top_collector/mod.rs
-mod conditional_collector;
-mod custom_score;
-mod topk;
-mod tweaked_score;
-
-pub mod fastfield;
-
-pub use conditional_collector::{
- CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
- ConditionalTopSegmentCollector, SearchMarker,
-};
-pub use custom_score::{CustomScoreTopCollector, DocScorer, ScorerForSegment};
-pub use topk::{Scored, TopK};
-pub use tweaked_score::{ModifierForSegment, ScoreModifier, TweakedScoreTopCollector};
Deleted tique/src/top_collector/topk.rs
-use std::{
- cmp::{Ordering, Reverse},
- collections::BinaryHeap,
-};
-
-pub struct TopK<S, D> {
- limit: usize,
- heap: BinaryHeap<Reverse<Scored<S, D>>>,
-}
-
-impl<S: PartialOrd, D: PartialOrd> TopK<S, D> {
- pub fn new(limit: usize) -> Self {
- Self {
- limit,
- heap: BinaryHeap::with_capacity(limit),
- }
- }
-
- pub fn len(&self) -> usize {
- self.heap.len()
- }
-
- pub fn is_empty(&self) -> bool {
- self.heap.is_empty()
- }
-
- pub fn visit(&mut self, score: S, doc: D) {
- if self.heap.len() < self.limit {
- self.heap.push(Reverse(Scored { score, doc }));
- } else if let Some(mut head) = self.heap.peek_mut() {
- if match head.0.score.partial_cmp(&score) {
- Some(Ordering::Equal) => doc < head.0.doc,
- Some(Ordering::Less) => true,
- _ => false,
- } {
- head.0.score = score;
- head.0.doc = doc;
- }
- }
- }
-
- pub fn into_sorted_vec(self) -> Vec<Scored<S, D>> {
- self.heap
- .into_sorted_vec()
- .into_iter()
- .map(|Reverse(item)| item)
- .collect()
- }
-
- pub fn into_vec(self) -> Vec<Scored<S, D>> {
- self.heap
- .into_vec()
- .into_iter()
- .map(|Reverse(item)| item)
- .collect()
- }
-}
-
-// TODO warn about exposing docid,segment_id publicly
-#[derive(Debug, Clone)]
-pub struct Scored<S, D> {
- pub score: S,
- pub doc: D,
-}
-
-impl<S: PartialOrd, D: PartialOrd> Scored<S, D> {
- pub fn new(score: S, doc: D) -> Self {
- Self { score, doc }
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> PartialOrd for Scored<S, D> {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- Some(self.cmp(other))
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> Ord for Scored<S, D> {
- #[inline]
- fn cmp(&self, other: &Self) -> Ordering {
- // Highest score first
- match self.score.partial_cmp(&other.score) {
- Some(Ordering::Equal) | None => {
- // Break even by lowest id
- other.doc.partial_cmp(&self.doc).unwrap_or(Ordering::Equal)
- }
- Some(rest) => rest,
- }
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> PartialEq for Scored<S, D> {
- fn eq(&self, other: &Self) -> bool {
- self.cmp(other) == Ordering::Equal
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> Eq for Scored<S, D> {}
-
-#[cfg(test)]
-mod tests {
- use super::{Scored, TopK};
-
- #[test]
- fn not_at_capacity() {
- let mut topk = TopK::new(3);
-
- assert!(topk.is_empty());
-
- topk.visit(0.8, 1);
- topk.visit(0.2, 3);
- topk.visit(0.3, 5);
-
- assert_eq!(3, topk.len());
-
- assert_eq!(
- vec![
- Scored::new(0.8, 1),
- Scored::new(0.3, 5),
- Scored::new(0.2, 3)
- ],
- topk.into_sorted_vec()
- )
- }
-
- #[test]
- fn at_capacity() {
- let mut topk = TopK::new(4);
-
- topk.visit(0.8, 1);
- topk.visit(0.2, 3);
- topk.visit(0.3, 5);
- topk.visit(0.9, 7);
- topk.visit(-0.2, 9);
-
- assert_eq!(4, topk.len());
-
- assert_eq!(
- vec![
- Scored::new(0.9, 7),
- Scored::new(0.8, 1),
- Scored::new(0.3, 5),
- Scored::new(0.2, 3)
- ],
- topk.into_sorted_vec()
- );
- }
-
- #[test]
- fn break_even_scores_by_lowest_doc() {
- let mut topk = TopK::new(5);
- topk.visit(0.1, 3);
- topk.visit(0.1, 1);
- topk.visit(0.1, 6);
- topk.visit(0.5, 5);
- topk.visit(0.5, 4);
- topk.visit(0.1, 2);
- assert_eq!(
- vec![
- Scored::new(0.5, 4),
- Scored::new(0.5, 5),
- Scored::new(0.1, 1),
- Scored::new(0.1, 2),
- Scored::new(0.1, 3),
- ],
- topk.into_sorted_vec()
- );
- }
-}
Deleted tique/src/top_collector/tweaked_score.rs
-use tantivy::{
- collector::{Collector, SegmentCollector},
- DocId, Result, Score, SegmentLocalId, SegmentReader,
-};
-
-use super::{
- CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
- ConditionalTopSegmentCollector,
-};
-
-pub trait ScoreModifier<T>: 'static {
- fn modify(&self, doc_id: DocId, score: Score) -> T;
-}
-
-impl<F, T> ScoreModifier<T> for F
-where
- F: 'static + Fn(DocId, Score) -> T,
-{
- fn modify(&self, doc_id: DocId, score: Score) -> T {
- (self)(doc_id, score)
- }
-}
-
-pub trait ModifierForSegment<T>: Sync {
- type Type: ScoreModifier<T>;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
-}
-
-impl<T, C, F> ModifierForSegment<T> for F
-where
- F: 'static + Sync + Send + Fn(&SegmentReader) -> C,
- C: ScoreModifier<T>,
-{
- type Type = C;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- (self)(reader)
- }
-}
-
-pub struct TweakedScoreTopSegmentCollector<T, C, F>
-where
- C: CheckCondition<T>,
-{
- modifier: F,
- collector: ConditionalTopSegmentCollector<T, C>,
-}
-
-impl<T, C, F> TweakedScoreTopSegmentCollector<T, C, F>
-where
- T: PartialOrd + Copy,
- C: CheckCondition<T>,
- F: ScoreModifier<T>,
-{
- fn new(segment_id: SegmentLocalId, limit: usize, condition: C, modifier: F) -> Self {
- Self {
- modifier,
- collector: ConditionalTopSegmentCollector::new(segment_id, limit, condition),
- }
- }
-}
-
-impl<T, C, F> SegmentCollector for TweakedScoreTopSegmentCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: CheckCondition<T>,
- F: ScoreModifier<T>,
-{
- type Fruit = CollectionResult<T>;
-
- fn collect(&mut self, doc: DocId, score: Score) {
- let score = self.modifier.modify(doc, score);
- self.collector.visit(doc, score);
- }
-
- fn harvest(self) -> Self::Fruit {
- self.collector.into_collection_result()
- }
-}
-
-pub struct TweakedScoreTopCollector<T, C, F>
-where
- C: ConditionForSegment<T>,
-{
- modifier_factory: F,
- condition_factory: C,
- collector: ConditionalTopCollector<T, C>,
-}
-
-impl<T, C, F> TweakedScoreTopCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: ConditionForSegment<T>,
- F: 'static + Sync + ModifierForSegment<T>,
-{
- pub fn new(limit: usize, condition_factory: C, modifier_factory: F) -> Self {
- Self {
- collector: ConditionalTopCollector::with_limit(limit, condition_factory.clone()),
- modifier_factory,
- condition_factory,
- }
- }
-}
-
-impl<T, C, F> Collector for TweakedScoreTopCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: ConditionForSegment<T> + Sync,
- F: 'static + ModifierForSegment<T>,
-{
- type Fruit = CollectionResult<T>;
- type Child = TweakedScoreTopSegmentCollector<T, C::Type, F::Type>;
-
- fn requires_scoring(&self) -> bool {
- true
- }
-
- fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
- Ok(self.collector.merge_many(children))
- }
-
- fn for_segment(
- &self,
- segment_id: SegmentLocalId,
- reader: &SegmentReader,
- ) -> Result<Self::Child> {
- let modifier = self.modifier_factory.for_segment(reader);
- Ok(TweakedScoreTopSegmentCollector::new(
- segment_id,
- self.collector.limit,
- self.condition_factory.for_segment(reader),
- modifier,
- ))
- }
-}
-
-#[cfg(test)]
-mod tests {
-
- use super::*;
-
- use tantivy::{query::AllQuery, schema::SchemaBuilder, Document, Index};
-
- #[test]
- fn integration() -> Result<()> {
- let builder = SchemaBuilder::new();
- let index = Index::create_in_ram(builder.build());
-
- let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
-
- for _ in 0..100 {
- writer.add_document(Document::new());
- }
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
-
- let colletor = TweakedScoreTopCollector::new(100, true, |_: &SegmentReader| {
- |doc_id: DocId, score: Score| f64::from(score) * f64::from(doc_id)
- });
-
- let result = searcher.search(&AllQuery, &colletor)?;
-
- assert_eq!(100, result.items.len());
- let mut item_iter = result.items.into_iter();
- let mut last_score = item_iter.next().unwrap().0;
-
- // An AllQuery ends up with every doc scoring the same, so
- // this means highest ids will come first
- for item in item_iter {
- assert!(last_score > item.0);
- last_score = item.0;
- }
-
- Ok(())
- }
-}