Merge branch 'typed_topk'
- Id: 798980e4aea2b4a2fd9b5a4273f01792179034a7
- Author: Caio
- Commit time: 2020-01-21T18:18:30+01:00
Modified cantine/src/index.rs
-use std::{cmp::Ordering, convert::TryFrom, ops::Neg};
+use std::{cmp::Ordering, convert::TryFrom};
use bincode;
use serde::{Deserialize, Serialize};
fastfield::FastFieldReader,
query::Query,
schema::{Field, Schema, SchemaBuilder, Value, FAST, STORED, TEXT},
- DocId, Document, Result, Score, Searcher, SegmentLocalId, SegmentReader, TantivyError,
+ DocAddress, DocId, Document, Result, Score, Searcher, SegmentLocalId, SegmentReader,
+ TantivyError,
};
use crate::model::{
Recipe, RecipeId, Sort,
};
-use tique::top_collector::{
- fastfield, CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
- CustomScoreTopCollector, SearchMarker, TweakedScoreTopCollector,
+use tique::conditional_collector::{
+ traits::{CheckCondition, ConditionForSegment},
+ Ascending, CollectionResult, Descending, TopCollector,
};
#[derive(Clone)]
doc
}
- fn addresses_to_ids<T>(
+ fn items_to_ids<T>(
&self,
searcher: &Searcher,
- addresses: &[SearchMarker<T>],
+ addresses: &[(T, DocAddress)],
) -> Result<Vec<RecipeId>> {
let mut items = Vec::with_capacity(addresses.len());
- for addr in addresses.iter() {
- let doc = searcher.doc(addr.doc)?;
+ for (_score, addr) in addresses.iter() {
+ let doc = searcher.doc(*addr)?;
if let Some(&Value::U64(id)) = doc.get_first(self.id) {
items.push(id);
} else {
macro_rules! collect {
($type: ty, $field:ident, $order:ident) => {
if let Some(after) = after {
- let top_collector = CustomScoreTopCollector::new(
- limit,
- after.as_paginator(self.id),
- fastfield::$order(self.features.$field),
- );
+ let top_collector =
+ TopCollector::<$type, $order, _>::new(limit, after.as_paginator(self.id))
+ .top_fast_field(self.features.$field);
self.render::<$type, _>(&searcher, query, top_collector)
} else {
- let top_collector = CustomScoreTopCollector::new(
- limit,
- true,
- fastfield::$order(self.features.$field),
- );
+ let top_collector = TopCollector::<$type, $order, _>::new(limit, true)
+ .top_fast_field(self.features.$field);
self.render::<$type, _>(&searcher, query, top_collector)
+ }
+ };
+
+ ($order:ident) => {
+ if let Some(after) = after {
+ let top_collector =
+ TopCollector::<_, $order, _>::new(limit, after.as_paginator(self.id));
+
+ self.render::<Score, _>(&searcher, query, top_collector)
+ } else {
+ let top_collector = TopCollector::<_, $order, _>::new(limit, true);
+
+ self.render::<Score, _>(&searcher, query, top_collector)
}
};
}
match sort {
- Sort::Relevance => {
- if let Some(after) = after {
- let top_collector =
- ConditionalTopCollector::with_limit(limit, after.as_paginator(self.id));
-
- self.render::<Score, _>(&searcher, query, top_collector)
- } else {
- let top_collector = ConditionalTopCollector::with_limit(limit, true);
-
- self.render::<Score, _>(&searcher, query, top_collector)
- }
- }
- Sort::RelevanceAsc => {
- if let Some(after) = after {
- let top_collector = TweakedScoreTopCollector::new(
- limit,
- after.as_paginator(self.id),
- |_: &SegmentReader| |_doc, score: Score| score.neg(),
- );
-
- self.render::<Score, _>(&searcher, query, top_collector)
- } else {
- let top_collector =
- TweakedScoreTopCollector::new(limit, true, |_: &SegmentReader| {
- |_doc, score: Score| score.neg()
- });
-
- self.render::<Score, _>(&searcher, query, top_collector)
- }
- }
- Sort::NumIngredients => collect!(u64, num_ingredients, descending),
- Sort::InstructionsLength => collect!(u64, instructions_length, descending),
- Sort::TotalTime => collect!(u64, total_time, descending),
- Sort::CookTime => collect!(u64, cook_time, descending),
- Sort::PrepTime => collect!(u64, prep_time, descending),
- Sort::Calories => collect!(u64, calories, descending),
- Sort::FatContent => collect!(f64, fat_content, descending),
- Sort::CarbContent => collect!(f64, carb_content, descending),
- Sort::ProteinContent => collect!(f64, protein_content, descending),
- Sort::NumIngredientsAsc => collect!(u64, num_ingredients, ascending),
- Sort::InstructionsLengthAsc => collect!(u64, instructions_length, ascending),
- Sort::TotalTimeAsc => collect!(u64, total_time, ascending),
- Sort::CookTimeAsc => collect!(u64, cook_time, ascending),
- Sort::PrepTimeAsc => collect!(u64, prep_time, ascending),
- Sort::CaloriesAsc => collect!(u64, calories, ascending),
- Sort::FatContentAsc => collect!(f64, fat_content, ascending),
- Sort::CarbContentAsc => collect!(f64, carb_content, ascending),
- Sort::ProteinContentAsc => collect!(f64, protein_content, ascending),
+ Sort::Relevance => collect!(Descending),
+ Sort::RelevanceAsc => collect!(Ascending),
+ Sort::NumIngredients => collect!(u64, num_ingredients, Descending),
+ Sort::InstructionsLength => collect!(u64, instructions_length, Descending),
+ Sort::TotalTime => collect!(u64, total_time, Descending),
+ Sort::CookTime => collect!(u64, cook_time, Descending),
+ Sort::PrepTime => collect!(u64, prep_time, Descending),
+ Sort::Calories => collect!(u64, calories, Descending),
+ Sort::FatContent => collect!(f64, fat_content, Descending),
+ Sort::CarbContent => collect!(f64, carb_content, Descending),
+ Sort::ProteinContent => collect!(f64, protein_content, Descending),
+ Sort::NumIngredientsAsc => collect!(u64, num_ingredients, Ascending),
+ Sort::InstructionsLengthAsc => collect!(u64, instructions_length, Ascending),
+ Sort::TotalTimeAsc => collect!(u64, total_time, Ascending),
+ Sort::CookTimeAsc => collect!(u64, cook_time, Ascending),
+ Sort::PrepTimeAsc => collect!(u64, prep_time, Ascending),
+ Sort::CaloriesAsc => collect!(u64, calories, Ascending),
+ Sort::FatContentAsc => collect!(f64, fat_content, Ascending),
+ Sort::CarbContentAsc => collect!(f64, carb_content, Ascending),
+ Sort::ProteinContentAsc => collect!(f64, protein_content, Ascending),
}
}
C: Collector<Fruit = CollectionResult<T>>,
{
let result = searcher.search(query, &collector)?;
- let items = self.addresses_to_ids(&searcher, &result.items)?;
+ let items = self.items_to_ids(&searcher, &result.items)?;
let num_items = items.len();
let cursor = if result.visited.saturating_sub(num_items) > 0 {
- let last_score = result.items[num_items - 1].score;
+ let last_score = result.items[num_items - 1].0;
let last_id = items[num_items - 1];
Some(last_score.as_after(last_id))
} else {
where
T: 'static + PartialOrd + Clone,
{
- fn check(&self, _sid: SegmentLocalId, doc_id: DocId, score: T) -> bool {
+ fn check(&self, _sid: SegmentLocalId, doc_id: DocId, score: T, ascending: bool) -> bool {
let recipe_id = self.id_reader.get(doc_id);
match self.ref_score.partial_cmp(&score) {
- Some(Ordering::Greater) => true,
+ Some(Ordering::Greater) => !ascending,
+ Some(Ordering::Less) => ascending,
Some(Ordering::Equal) => self.ref_id < recipe_id,
- _ => false,
+ None => false,
}
}
}
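Below is a minimal usage sketch of the new collector for the relevance sorts above (illustrative only, not part of the commit; assumes an existing tantivy `Searcher` and query):

```rust
use tantivy::{query::Query, Result, Score, Searcher};
use tique::conditional_collector::{Descending, TopCollector};

// Top 10 hits by descending score; `true` means "no extra collection condition".
fn first_page(searcher: &Searcher, query: &dyn Query) -> Result<()> {
    let collector = TopCollector::<Score, Descending, _>::new(10, true);
    let result = searcher.search(query, &collector)?;
    // `result.items` is a Vec<(Score, DocAddress)>, merged and sorted across segments.
    for (score, addr) in &result.items {
        println!("{:?} => {}", addr, score);
    }
    Ok(())
}
```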
Modified tique/src/lib.rs
+pub mod conditional_collector;
pub mod queryparser;
-pub mod top_collector;
mod derive;
Renamed tique/src/top_collector/conditional_collector.rs to tique/src/conditional_collector/topk.rs
-use std::marker::PhantomData;
-
-use tantivy::{
- collector::{Collector, SegmentCollector},
- DocAddress, DocId, Result, Score, SegmentLocalId, SegmentReader,
+use std::{
+ cmp::{Ordering, Reverse},
+ collections::BinaryHeap,
};
-use super::{Scored, TopK};
+use tantivy::DocId;
-pub trait ConditionForSegment<T>: Clone {
- type Type: CheckCondition<T>;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
+use super::CollectionResult;
+
+pub trait TopK<T, D> {
+ const ASCENDING: bool;
+ fn visit(&mut self, score: T, doc: D);
+ fn into_sorted_vec(self) -> Vec<(T, D)>;
+ fn into_vec(self) -> Vec<(T, D)>;
}
-impl<T, C, F> ConditionForSegment<T> for F
-where
- F: Clone + Fn(&SegmentReader) -> C,
- C: CheckCondition<T>,
-{
- type Type = C;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- (self)(reader)
- }
+pub trait TopKProvider<T: PartialOrd> {
+ type Child: TopK<T, DocId>;
+
+ fn new_topk(limit: usize) -> Self::Child;
+ fn merge_many(limit: usize, items: Vec<CollectionResult<T>>) -> CollectionResult<T>;
}
-impl<T> ConditionForSegment<T> for bool {
- type Type = bool;
- fn for_segment(&self, _reader: &SegmentReader) -> Self::Type {
- *self
- }
-}
+pub struct Ascending;
-pub trait CheckCondition<T>: 'static + Clone {
- fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T) -> bool;
-}
+impl<T: PartialOrd> TopKProvider<T> for Ascending {
+ type Child = AscendingTopK<T, DocId>;
-impl<T> CheckCondition<T> for bool {
- fn check(&self, _: SegmentLocalId, _: DocId, _: T) -> bool {
- *self
- }
-}
-
-impl<F, T> CheckCondition<T> for F
-where
- F: 'static + Clone + Fn(SegmentLocalId, DocId, T) -> bool,
-{
- fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T) -> bool {
- (self)(segment_id, doc_id, score)
- }
-}
-
-pub type SearchMarker<T> = Scored<T, DocAddress>;
-
-impl<T> CheckCondition<T> for SearchMarker<T>
-where
- T: 'static + PartialOrd + Clone,
-{
- fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T) -> bool {
- // So: only collect items that would come _after_ this marker
- *self > Scored::new(score, DocAddress(segment_id, doc_id))
- }
-}
-
-pub struct ConditionalTopCollector<T, F>
-where
- F: ConditionForSegment<T>,
-{
- pub limit: usize,
- condition_factory: F,
- _marker: PhantomData<T>,
-}
-
-impl<T, F> ConditionalTopCollector<T, F>
-where
- T: PartialOrd,
- F: ConditionForSegment<T>,
-{
- pub fn with_limit(limit: usize, condition_factory: F) -> Self {
- if limit < 1 {
- panic!("Limit must be greater than 0");
- }
- ConditionalTopCollector {
- limit,
- condition_factory,
- _marker: PhantomData,
- }
+ fn new_topk(limit: usize) -> Self::Child {
+ AscendingTopK::new(limit)
}
- pub fn merge_many(&self, children: Vec<CollectionResult<T>>) -> CollectionResult<T> {
- CollectionResult::merge_many(self.limit, children)
- }
-}
+ fn merge_many(limit: usize, items: Vec<CollectionResult<T>>) -> CollectionResult<T> {
+ let mut topk = AscendingTopK::new(limit);
-impl<F> Collector for ConditionalTopCollector<Score, F>
-where
- F: ConditionForSegment<Score> + Sync,
-{
- type Fruit = CollectionResult<Score>;
- type Child = ConditionalTopSegmentCollector<Score, F::Type>;
-
- fn requires_scoring(&self) -> bool {
- true
- }
-
- fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
- Ok(self.merge_many(children))
- }
-
- fn for_segment(
- &self,
- segment_id: SegmentLocalId,
- reader: &SegmentReader,
- ) -> Result<Self::Child> {
- Ok(ConditionalTopSegmentCollector::new(
- segment_id,
- self.limit,
- self.condition_factory.for_segment(reader),
- ))
- }
-}
-
-pub struct ConditionalTopSegmentCollector<T, F>
-where
- F: CheckCondition<T>,
-{
- segment_id: SegmentLocalId,
- collected: TopK<T, DocId>,
- visited: usize,
- total: usize,
- condition: F,
-}
-
-impl<T, F> ConditionalTopSegmentCollector<T, F>
-where
- T: PartialOrd + Copy,
- F: CheckCondition<T>,
-{
- pub fn new(segment_id: SegmentLocalId, limit: usize, condition: F) -> Self {
- ConditionalTopSegmentCollector {
- collected: TopK::new(limit),
- segment_id,
- condition,
- visited: 0,
- total: 0,
- }
- }
-
- #[cfg(test)]
- fn len(&self) -> usize {
- self.collected.len()
- }
-
- #[inline(always)]
- pub fn visit(&mut self, doc: DocId, score: T) {
- self.total += 1;
- if self.condition.check(self.segment_id, doc, score) {
- self.visited += 1;
- self.collected.visit(score, doc);
- }
- }
-
- pub fn into_collection_result(self) -> CollectionResult<T> {
- let segment_id = self.segment_id;
- let items = self
- .collected
- .into_vec()
- .into_iter()
- .map(|Scored { score, doc }| Scored {
- score,
- doc: DocAddress(segment_id, doc),
- })
- .collect();
-
- CollectionResult {
- total: self.total,
- visited: self.visited,
- items,
- }
- }
-}
-
-impl<F> SegmentCollector for ConditionalTopSegmentCollector<Score, F>
-where
- F: CheckCondition<Score>,
-{
- type Fruit = CollectionResult<Score>;
-
- fn collect(&mut self, doc: DocId, score: Score) {
- self.visit(doc, score);
- }
-
- fn harvest(self) -> Self::Fruit {
- self.into_collection_result()
- }
-}
-
-#[derive(Debug)]
-pub struct CollectionResult<T> {
- pub total: usize,
- pub visited: usize,
- pub items: Vec<SearchMarker<T>>,
-}
-
-impl<T: PartialOrd> CollectionResult<T> {
- pub fn merge_many(limit: usize, items: Vec<CollectionResult<T>>) -> CollectionResult<T> {
- let mut topk = TopK::new(limit);
let mut total = 0;
let mut visited = 0;
total += item.total;
visited += item.visited;
- for Scored { score, doc } in item.items {
+ for (score, doc) in item.items {
+ topk.visit(score, doc);
+ }
+ }
+
+ CollectionResult {
+ total,
+ visited,
+ items: topk.into_sorted_vec().into_iter().collect(),
+ }
+ }
+}
+
+pub struct Descending;
+
+impl<T: PartialOrd> TopKProvider<T> for Descending {
+ type Child = DescendingTopK<T, DocId>;
+
+ fn new_topk(limit: usize) -> Self::Child {
+ DescendingTopK {
+ limit,
+ heap: BinaryHeap::with_capacity(limit),
+ }
+ }
+
+ fn merge_many(limit: usize, items: Vec<CollectionResult<T>>) -> CollectionResult<T> {
+ let mut topk = DescendingTopK::new(limit);
+
+ let mut total = 0;
+ let mut visited = 0;
+
+ for item in items {
+ total += item.total;
+ visited += item.visited;
+
+ for (score, doc) in item.items {
topk.visit(score, doc);
}
}
}
}
+pub struct AscendingTopK<S, D> {
+ limit: usize,
+ heap: BinaryHeap<Scored<S, Reverse<D>>>,
+}
+
+pub struct DescendingTopK<S, D> {
+ limit: usize,
+ heap: BinaryHeap<Reverse<Scored<S, D>>>,
+}
+
+impl<T: PartialOrd, D: PartialOrd> AscendingTopK<T, D> {
+ pub(crate) fn new(limit: usize) -> Self {
+ Self {
+ limit,
+ heap: BinaryHeap::with_capacity(limit),
+ }
+ }
+
+ fn visit(&mut self, score: T, doc: D) {
+ let scored = Scored {
+ score,
+ doc: Reverse(doc),
+ };
+ if self.heap.len() < self.limit {
+ self.heap.push(scored);
+ } else if let Some(mut head) = self.heap.peek_mut() {
+ if head.cmp(&scored) == Ordering::Greater {
+ head.score = scored.score;
+ head.doc = scored.doc;
+ }
+ }
+ }
+
+ fn into_sorted_vec(self) -> Vec<(T, D)> {
+ self.heap
+ .into_sorted_vec()
+ .into_iter()
+ .map(|s| (s.score, s.doc.0))
+ .collect()
+ }
+
+ fn into_vec(self) -> Vec<(T, D)> {
+ self.heap
+ .into_vec()
+ .into_iter()
+ .map(|s| (s.score, s.doc.0))
+ .collect()
+ }
+}
+
+impl<T: PartialOrd, D: PartialOrd> DescendingTopK<T, D> {
+ pub(crate) fn new(limit: usize) -> Self {
+ Self {
+ limit,
+ heap: BinaryHeap::with_capacity(limit),
+ }
+ }
+
+ fn visit(&mut self, score: T, doc: D) {
+ let scored = Reverse(Scored { score, doc });
+ if self.heap.len() < self.limit {
+ self.heap.push(scored);
+ } else if let Some(mut head) = self.heap.peek_mut() {
+ if head.cmp(&scored) == Ordering::Greater {
+ head.0.score = scored.0.score;
+ head.0.doc = scored.0.doc;
+ }
+ }
+ }
+
+ fn into_sorted_vec(self) -> Vec<(T, D)> {
+ self.heap
+ .into_sorted_vec()
+ .into_iter()
+ .map(|s| (s.0.score, s.0.doc))
+ .collect()
+ }
+
+ fn into_vec(self) -> Vec<(T, D)> {
+ self.heap
+ .into_vec()
+ .into_iter()
+ .map(|s| (s.0.score, s.0.doc))
+ .collect()
+ }
+}
+
+impl<T: PartialOrd> TopK<T, DocId> for AscendingTopK<T, DocId> {
+ const ASCENDING: bool = true;
+
+ fn visit(&mut self, score: T, doc: DocId) {
+ AscendingTopK::visit(self, score, doc);
+ }
+
+ fn into_sorted_vec(self) -> Vec<(T, DocId)> {
+ AscendingTopK::into_sorted_vec(self)
+ }
+
+ fn into_vec(self) -> Vec<(T, DocId)> {
+ AscendingTopK::into_vec(self)
+ }
+}
+
+impl<T: PartialOrd> TopK<T, DocId> for DescendingTopK<T, DocId> {
+ const ASCENDING: bool = false;
+
+ fn visit(&mut self, score: T, doc: DocId) {
+ DescendingTopK::visit(self, score, doc);
+ }
+
+ fn into_sorted_vec(self) -> Vec<(T, DocId)> {
+ DescendingTopK::into_sorted_vec(self)
+ }
+
+ fn into_vec(self) -> Vec<(T, DocId)> {
+ DescendingTopK::into_vec(self)
+ }
+}
+
+pub(crate) struct Scored<S, D> {
+ pub score: S,
+ pub doc: D,
+}
+
+impl<S: PartialOrd, D: PartialOrd> Scored<S, D> {
+ pub(crate) fn new(score: S, doc: D) -> Self {
+ Self { score, doc }
+ }
+}
+
+impl<S: PartialOrd, D: PartialOrd> PartialOrd for Scored<S, D> {
+ fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+ Some(self.cmp(other))
+ }
+}
+
+impl<S: PartialOrd, D: PartialOrd> Ord for Scored<S, D> {
+ #[inline]
+ fn cmp(&self, other: &Self) -> Ordering {
+ // Highest score first
+ match self.score.partial_cmp(&other.score) {
+ Some(Ordering::Equal) | None => {
+ // Break even by lowest id
+ other.doc.partial_cmp(&self.doc).unwrap_or(Ordering::Equal)
+ }
+ Some(rest) => rest,
+ }
+ }
+}
+
+impl<S: PartialOrd, D: PartialOrd> PartialEq for Scored<S, D> {
+ fn eq(&self, other: &Self) -> bool {
+ self.cmp(other) == Ordering::Equal
+ }
+}
+
+impl<S: PartialOrd, D: PartialOrd> Eq for Scored<S, D> {}
+
#[cfg(test)]
mod tests {
+
use super::*;
- #[test]
- fn condition_is_checked() {
- const LIMIT: usize = 4;
-
- let mut nil_collector = ConditionalTopSegmentCollector::new(0, LIMIT, false);
-
- let mut top_collector = ConditionalTopSegmentCollector::new(0, LIMIT, true);
-
- let condition = |_sid, doc, _score| doc % 2 == 1;
-
- let mut just_odds = ConditionalTopSegmentCollector::new(0, LIMIT, condition);
-
- for i in 0..4 {
- nil_collector.collect(i, 420.0);
- top_collector.collect(i, 420.0);
- just_odds.collect(i, 420.0);
+ fn check_topk<S, D, K>(mut topk: K, input: Vec<(S, D)>, wanted: Vec<(S, D)>)
+ where
+ S: PartialOrd + std::fmt::Debug,
+ D: PartialOrd + std::fmt::Debug,
+ K: TopK<S, D>,
+ {
+ for (score, id) in input {
+ topk.visit(score, id);
}
- assert_eq!(0, nil_collector.len());
- assert_eq!(4, top_collector.len());
- assert_eq!(2, just_odds.len());
-
- // Verify that the collected items respect the condition
- let result = just_odds.harvest();
- assert_eq!(4, result.total);
- assert_eq!(2, result.visited);
- for scored in result.items {
- let DocAddress(seg_id, doc_id) = scored.doc;
- assert!(condition(seg_id, doc_id, scored.score))
- }
+ assert_eq!(wanted, topk.into_sorted_vec());
}
#[test]
- fn collection_with_a_marker_smoke() {
- // Doc id=4 on segment=0 had score=0.5
- let marker = Scored::new(0.5, DocAddress(0, 4));
- let mut collector = ConditionalTopSegmentCollector::new(0, 3, marker);
+ fn not_at_capacity() {
+ let input = vec![(0.8, 1), (0.2, 3), (0.5, 4), (0.3, 5)];
+ let mut wanted = vec![(0.2, 3), (0.3, 5), (0.5, 4), (0.8, 1)];
- // Every doc with a higher score has appeared already
- collector.collect(7, 0.6);
- collector.collect(5, 0.7);
- assert_eq!(0, collector.len());
+ check_topk(AscendingTopK::new(4), input.clone(), wanted.clone());
- // Docs with the same score, but lower id too
- collector.collect(3, 0.5);
- collector.collect(2, 0.5);
- assert_eq!(0, collector.len());
-
- // And, of course, the same doc should not be collected
- collector.collect(4, 0.5);
- assert_eq!(0, collector.len());
-
- // Lower scores are in
- collector.collect(1, 0.0);
- // Same score but higher doc, too
- collector.collect(6, 0.5);
-
- assert_eq!(2, collector.len());
+ wanted.reverse();
+ check_topk(DescendingTopK::new(4), input, wanted);
}
#[test]
- fn fruits_are_merged_correctly() {
- let collector = ConditionalTopCollector::with_limit(5, true);
+ fn at_capacity() {
+ let input = vec![(0.8, 1), (0.2, 3), (0.3, 5), (0.9, 7), (-0.2, 9)];
- let merged = collector
- .merge_fruits(vec![
- // S0
- CollectionResult {
- total: 1,
- visited: 1,
- items: vec![Scored::new(0.5, DocAddress(0, 1))],
- },
- // S1 has a doc that scored the same as S0, so
- // it should only appear *after* the one in S0
- CollectionResult {
- total: 1,
- visited: 1,
- items: vec![
- Scored::new(0.5, DocAddress(1, 1)),
- Scored::new(0.6, DocAddress(1, 2)),
- ],
- },
- // S2 has two evenly scored docs, the one with
- // the lowest internal id should appear first
- CollectionResult {
- total: 1,
- visited: 1,
- items: vec![
- Scored::new(0.2, DocAddress(2, 2)),
- Scored::new(0.2, DocAddress(2, 1)),
- ],
- },
- ])
- .unwrap();
+ check_topk(
+ AscendingTopK::new(3),
+ input.clone(),
+ vec![(-0.2, 9), (0.2, 3), (0.3, 5)],
+ );
- assert_eq!(
- vec![
- Scored::new(0.6, DocAddress(1, 2)),
- Scored::new(0.5, DocAddress(0, 1)),
- Scored::new(0.5, DocAddress(1, 1)),
- Scored::new(0.2, DocAddress(2, 1)),
- Scored::new(0.2, DocAddress(2, 2))
- ],
- merged.items
+ check_topk(
+ DescendingTopK::new(3),
+ input,
+ vec![(0.9, 7), (0.8, 1), (0.3, 5)],
);
}
- use tantivy::{query::AllQuery, schema, Document, Index, Result};
-
#[test]
- fn only_collect_even_public_ids() -> Result<()> {
- let mut builder = schema::SchemaBuilder::new();
+ fn break_even_scores_by_lowest_doc() {
+ let input = vec![(0.1, 3), (0.1, 1), (0.1, 6), (0.5, 5), (0.5, 4), (0.1, 2)];
- let id_field = builder.add_u64_field("public_id", schema::FAST);
+ check_topk(
+ AscendingTopK::new(5),
+ input.clone(),
+ vec![(0.1, 1), (0.1, 2), (0.1, 3), (0.1, 6), (0.5, 4)],
+ );
- let index = Index::create_in_ram(builder.build());
-
- let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
-
- const NUM_DOCS: u64 = 10;
- for public_id in 0..NUM_DOCS {
- let mut doc = Document::new();
- doc.add_u64(id_field, public_id);
- writer.add_document(doc);
- }
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
-
- let condition_factory = |reader: &SegmentReader| {
- let id_reader = reader.fast_fields().u64(id_field).unwrap();
-
- move |_segment_id, doc_id, _score| {
- let stored_id = id_reader.get(doc_id);
- stored_id % 2 == 0
- }
- };
- let results = searcher.search(
- &AllQuery,
- &ConditionalTopCollector::with_limit(NUM_DOCS as usize, condition_factory),
- )?;
-
- assert_eq!(5, results.items.len());
-
- Ok(())
+ check_topk(
+ DescendingTopK::new(5),
+ input,
+ vec![(0.5, 4), (0.5, 5), (0.1, 1), (0.1, 2), (0.1, 3)],
+ );
}
}
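The two heap flavors introduced above follow the usual bounded top-k pattern: a `BinaryHeap` capped at `limit` entries, with the worst retained entry at the top so it can be swapped out cheaply. A standalone sketch of the descending case (simplified tie-breaking; the commit's code wraps entries in `Scored`, which breaks score ties by lowest doc id):

```rust
use std::{cmp::Reverse, collections::BinaryHeap};

// Keep the `limit` largest (score, doc) pairs seen, highest first.
fn top_k_descending(limit: usize, items: &[(u32, u32)]) -> Vec<(u32, u32)> {
    let mut heap: BinaryHeap<Reverse<(u32, u32)>> = BinaryHeap::with_capacity(limit);
    for &item in items {
        if heap.len() < limit {
            heap.push(Reverse(item));
        } else if let Some(mut head) = heap.peek_mut() {
            // `head.0` is the smallest retained pair; replace it if the new one is larger.
            if item > head.0 {
                head.0 = item;
            }
        }
    }
    // Ascending order of `Reverse` is descending order of the wrapped pair.
    heap.into_sorted_vec().into_iter().map(|Reverse(x)| x).collect()
}
```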
Renamed tique/src/top_collector/fastfield.rs to tique/src/conditional_collector/traits.rs
-use std::{marker::PhantomData, ops::Neg};
+use std::cmp::Ordering;
-use tantivy::{
- fastfield::{FastFieldReader, FastValue},
- schema::Field,
- DocId, SegmentReader,
-};
+use tantivy::{DocAddress, DocId, SegmentLocalId, SegmentReader};
-use super::{DocScorer, ScorerForSegment};
+use super::topk::Scored;
-pub fn descending<T>(field: Field) -> DescendingFastField<T> {
- DescendingFastField(field, PhantomData)
+pub trait ConditionForSegment<T>: Clone {
+ type Type: CheckCondition<T>;
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
}
-pub fn ascending<T>(field: Field) -> AscendingFastField<T> {
- AscendingFastField(field, PhantomData)
-}
-
-pub struct DescendingFastField<T>(Field, PhantomData<T>);
-
-pub struct AscendingFastField<T>(Field, PhantomData<T>);
-
-macro_rules! impl_scorer_for_segment {
- ($type: ident) => {
- impl ScorerForSegment<$type> for DescendingFastField<$type> {
- type Type = DescendingScorer<$type>;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- let scorer = reader.fast_fields().$type(self.0).expect("Field is FAST");
- DescendingScorer(scorer)
- }
- }
-
- impl ScorerForSegment<$type> for AscendingFastField<$type> {
- type Type = AscendingScorer<$type>;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- let scorer = reader.fast_fields().$type(self.0).expect("Field is FAST");
- AscendingScorer(scorer)
- }
- }
- };
-}
-
-impl_scorer_for_segment!(f64);
-impl_scorer_for_segment!(i64);
-impl_scorer_for_segment!(u64);
-
-pub struct DescendingScorer<T: FastValue>(FastFieldReader<T>);
-
-pub struct AscendingScorer<T: FastValue>(FastFieldReader<T>);
-
-impl<T> DocScorer<T> for DescendingScorer<T>
+impl<T, C, F> ConditionForSegment<T> for F
where
- T: FastValue + 'static,
+ F: Clone + Fn(&SegmentReader) -> C,
+ C: CheckCondition<T>,
+{
+ type Type = C;
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
+ (self)(reader)
+ }
+}
+
+impl<T> ConditionForSegment<T> for bool {
+ type Type = bool;
+ fn for_segment(&self, _reader: &SegmentReader) -> Self::Type {
+ *self
+ }
+}
+
+pub trait CheckCondition<T>: 'static + Clone {
+ fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T, ascending: bool) -> bool;
+}
+
+impl<T> CheckCondition<T> for bool {
+ fn check(&self, _: SegmentLocalId, _: DocId, _: T, _: bool) -> bool {
+ *self
+ }
+}
+
+impl<F, T> CheckCondition<T> for F
+where
+ F: 'static + Clone + Fn(SegmentLocalId, DocId, T, bool) -> bool,
+{
+ fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T, ascending: bool) -> bool {
+ (self)(segment_id, doc_id, score, ascending)
+ }
+}
+
+impl<T> CheckCondition<T> for (T, DocAddress)
+where
+ T: 'static + PartialOrd + Clone + Copy,
+{
+ fn check(&self, segment_id: SegmentLocalId, doc_id: DocId, score: T, ascending: bool) -> bool {
+ let wanted = if ascending {
+ Ordering::Less
+ } else {
+ Ordering::Greater
+ };
+
+ Scored::new(self.0, self.1).cmp(&Scored::new(score, DocAddress(segment_id, doc_id)))
+ == wanted
+ }
+}
+
+pub trait ScorerForSegment<T>: Sync {
+ type Type: DocScorer<T>;
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
+}
+
+impl<T, C, F> ScorerForSegment<T> for F
+where
+ F: 'static + Sync + Send + Fn(&SegmentReader) -> C,
+ C: DocScorer<T>,
+{
+ type Type = C;
+
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
+ (self)(reader)
+ }
+}
+
+pub trait DocScorer<T>: 'static {
+ fn score(&self, doc_id: DocId) -> T;
+}
+
+impl<F, T> DocScorer<T> for F
+where
+ F: 'static + Sync + Send + Fn(DocId) -> T,
{
fn score(&self, doc_id: DocId) -> T {
- self.0.get(doc_id)
+ (self)(doc_id)
}
-}
-
-impl DocScorer<u64> for AscendingScorer<u64> {
- fn score(&self, doc_id: DocId) -> u64 {
- std::u64::MAX - self.0.get(doc_id)
- }
-}
-
-macro_rules! impl_neg_reversed_scorer {
- ($type: ty) => {
- impl DocScorer<$type> for AscendingScorer<$type> {
- fn score(&self, doc_id: DocId) -> $type {
- self.0.get(doc_id).neg()
- }
- }
- };
-}
-
-impl_neg_reversed_scorer!(i64);
-impl_neg_reversed_scorer!(f64);
-
-#[cfg(test)]
-use super::{CollectionResult, CustomScoreTopCollector};
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use tantivy::{
- query::AllQuery,
- schema::{SchemaBuilder, FAST},
- DocAddress, Document, Index, Result,
- };
-
- fn just_the_ids<T: PartialOrd>(res: CollectionResult<T>) -> Vec<DocId> {
- res.items
- .into_iter()
- .map(|item| {
- let DocAddress(_segment, id) = item.doc;
- id
- })
- .collect()
- }
-
- macro_rules! check_order_from_sorted_values {
- ($name: ident, $field: ident, $add: ident, $type: ty, $values: expr) => {
- #[test]
- fn $name() -> Result<()> {
- let mut sb = SchemaBuilder::new();
-
- let field = sb.$field("field", FAST);
- let index = Index::create_in_ram(sb.build());
- let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
-
- for v in $values.iter() {
- let mut doc = Document::new();
- doc.$add(field, *v);
- writer.add_document(doc);
- }
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
- let size = $values.len();
- let condition = true;
-
- let collector =
- CustomScoreTopCollector::new(size, condition, descending::<$type>(field));
-
- let reversed_collector =
- CustomScoreTopCollector::new(size, condition, ascending::<$type>(field));
-
- let (top, reversed_top) =
- searcher.search(&AllQuery, &(collector, reversed_collector))?;
-
- let sorted_scores: Vec<$type> = top.items.iter().map(|r| r.score).collect();
- assert_eq!(
- $values,
- sorted_scores.as_slice(),
- "found scores don't match input"
- );
-
- let ids = just_the_ids(top);
- let mut reversed_ids = just_the_ids(reversed_top);
-
- reversed_ids.reverse();
- assert_eq!(
- ids,
- reversed_ids.as_slice(),
- "should have found the same ids, in reversed order"
- );
-
- Ok(())
- }
- };
- }
-
- check_order_from_sorted_values!(
- u64_field_sort_functionality,
- add_u64_field,
- add_u64,
- u64,
- [3, 2, 1, 0]
- );
-
- check_order_from_sorted_values!(
- i64_field_sort_functionality,
- add_i64_field,
- add_i64,
- i64,
- [3, 2, 1, 0, -1, -2, -3]
- );
-
- check_order_from_sorted_values!(
- f64_field_sort_functionality,
- add_f64_field,
- add_f64,
- f64,
- [100.0, 42.0, 0.71, 0.42]
- );
}
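A rough sketch of wiring a custom per-segment condition through the new four-argument `check` (the `public_id` fast field is illustrative; not part of the commit):

```rust
use tantivy::{
    query::AllQuery, schema::Field, DocId, Result, Score, Searcher, SegmentLocalId, SegmentReader,
};
use tique::conditional_collector::{Descending, TopCollector};

// Collect the top 10 docs whose `public_id` fast field is even; the closure's
// extra `bool` parameter is the `ascending` flag added by this change.
fn top_even_ids(searcher: &Searcher, public_id: Field) -> Result<usize> {
    let condition_for_segment = move |reader: &SegmentReader| {
        let ids = reader.fast_fields().u64(public_id).expect("u64 fast field");
        move |_seg: SegmentLocalId, doc: DocId, _score: Score, _asc: bool| ids.get(doc) % 2 == 0
    };
    let collector = TopCollector::<Score, Descending, _>::new(10, condition_for_segment);
    let result = searcher.search(&AllQuery, &collector)?;
    Ok(result.items.len())
}
```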
Renamed tique/src/top_collector/custom_score.rs to tique/src/conditional_collector/custom_score.rs
+use std::marker::PhantomData;
+
use tantivy::{
collector::{Collector, SegmentCollector},
DocId, Result, Score, SegmentLocalId, SegmentReader,
};
use super::{
- CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
- ConditionalTopSegmentCollector,
+ top_collector::TopSegmentCollector,
+ topk::{TopK, TopKProvider},
+ traits::{CheckCondition, ConditionForSegment, DocScorer, ScorerForSegment},
+ CollectionResult,
};
-pub struct CustomScoreTopCollector<T, C, F>
+pub struct CustomScoreTopCollector<T, P, C, S>
where
+ T: PartialOrd,
+ P: TopKProvider<T>,
C: ConditionForSegment<T>,
{
- scorer_factory: F,
- condition_factory: C,
- collector: ConditionalTopCollector<T, C>,
+ limit: usize,
+ scorer_for_segment: S,
+ condition_for_segment: C,
+ _score: PhantomData<T>,
+ _provider: PhantomData<P>,
}
-impl<T, C, F> CustomScoreTopCollector<T, C, F>
+impl<T, P, C, S> CustomScoreTopCollector<T, P, C, S>
where
- T: 'static + PartialOrd + Copy + Sync + Send,
+ T: PartialOrd,
+ P: TopKProvider<T>,
C: ConditionForSegment<T>,
- F: 'static + Sync + ScorerForSegment<T>,
{
- pub fn new(limit: usize, condition_factory: C, scorer_factory: F) -> Self {
+ pub fn new(limit: usize, condition_for_segment: C, scorer_for_segment: S) -> Self {
Self {
- collector: ConditionalTopCollector::with_limit(limit, condition_factory.clone()),
- scorer_factory,
- condition_factory,
+ limit,
+ scorer_for_segment,
+ condition_for_segment,
+ _score: PhantomData,
+ _provider: PhantomData,
}
}
}
-pub trait ScorerForSegment<T>: Sync {
- type Type: DocScorer<T>;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
-}
-
-impl<T, C, F> ScorerForSegment<T> for F
+impl<T, P, C, S> Collector for CustomScoreTopCollector<T, P, C, S>
where
- F: 'static + Sync + Send + Fn(&SegmentReader) -> C,
- C: DocScorer<T>,
-{
- type Type = C;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- (self)(reader)
- }
-}
-
-impl<T, C, F> Collector for CustomScoreTopCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: ConditionForSegment<T> + Sync,
- F: 'static + ScorerForSegment<T>,
+ T: 'static + PartialOrd + Copy + Send + Sync,
+ P: 'static + Send + Sync + TopKProvider<T>,
+ C: Sync + ConditionForSegment<T>,
+ S: ScorerForSegment<T>,
{
type Fruit = CollectionResult<T>;
- type Child = CustomScoreTopSegmentCollector<T, C::Type, F::Type>;
+ type Child = CustomScoreTopSegmentCollector<T, C::Type, S::Type, P::Child>;
fn requires_scoring(&self) -> bool {
false
}
fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
- Ok(self.collector.merge_many(children))
+ Ok(P::merge_many(self.limit, children))
}
fn for_segment(
segment_id: SegmentLocalId,
reader: &SegmentReader,
) -> Result<Self::Child> {
- let scorer = self.scorer_factory.for_segment(reader);
+ let scorer = self.scorer_for_segment.for_segment(reader);
Ok(CustomScoreTopSegmentCollector::new(
segment_id,
- self.collector.limit,
- self.condition_factory.for_segment(reader),
+ P::new_topk(self.limit),
scorer,
+ self.condition_for_segment.for_segment(reader),
))
}
}
-pub struct CustomScoreTopSegmentCollector<T, C, F>
+pub struct CustomScoreTopSegmentCollector<T, C, S, K>
where
C: CheckCondition<T>,
+ K: TopK<T, DocId>,
{
- scorer: F,
- collector: ConditionalTopSegmentCollector<T, C>,
+ scorer: S,
+ collector: TopSegmentCollector<T, K, C>,
}
-impl<T, C, F> CustomScoreTopSegmentCollector<T, C, F>
+impl<T, C, S, K> CustomScoreTopSegmentCollector<T, C, S, K>
where
- T: PartialOrd + Copy,
+ T: Copy,
C: CheckCondition<T>,
- F: DocScorer<T>,
+ K: TopK<T, DocId>,
{
- fn new(segment_id: SegmentLocalId, limit: usize, condition: C, scorer: F) -> Self {
+ pub fn new(segment_id: SegmentLocalId, topk: K, scorer: S, condition: C) -> Self {
Self {
scorer,
- collector: ConditionalTopSegmentCollector::new(segment_id, limit, condition),
+ collector: TopSegmentCollector::new(segment_id, topk, condition),
}
}
}
-impl<T, C, F> SegmentCollector for CustomScoreTopSegmentCollector<T, C, F>
+impl<T, C, S, K> SegmentCollector for CustomScoreTopSegmentCollector<T, C, S, K>
where
- T: 'static + PartialOrd + Copy + Sync + Send,
+ T: 'static + PartialOrd + Copy + Send + Sync,
+ K: 'static + TopK<T, DocId>,
C: CheckCondition<T>,
- F: DocScorer<T>,
+ S: DocScorer<T>,
{
type Fruit = CollectionResult<T>;
fn collect(&mut self, doc: DocId, _: Score) {
let score = self.scorer.score(doc);
- self.collector.visit(doc, score);
+ self.collector.collect(doc, score);
}
fn harvest(self) -> Self::Fruit {
}
}
-pub trait DocScorer<T>: 'static {
- fn score(&self, doc_id: DocId) -> T;
-}
-
-impl<F, T> DocScorer<T> for F
-where
- F: 'static + Sync + Send + Fn(DocId) -> T,
-{
- fn score(&self, doc_id: DocId) -> T {
- (self)(doc_id)
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
+ use crate::conditional_collector::{topk::AscendingTopK, Descending};
use tantivy::{query::AllQuery, schema::SchemaBuilder, Document, Index};
#[test]
fn custom_segment_scorer_gets_called() {
- // Use the doc_id as the score
- let mut collector = CustomScoreTopSegmentCollector::new(0, 1, true, |doc_id| doc_id);
+ let mut collector = CustomScoreTopSegmentCollector::new(
+ 0,
+ AscendingTopK::new(1),
+ // Use the doc_id as the score
+ |doc_id| doc_id,
+ true,
+ );
// So that whatever we provide as a score
collector.collect(1, 42.0);
let got = &res.items[0];
// Is disregarded and doc_id is used instead
- assert_eq!(got.doc.1, got.score)
+ assert_eq!((got.1).1, got.0)
}
#[test]
let builder = SchemaBuilder::new();
let index = Index::create_in_ram(builder.build());
- let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
// We add 100 documents to our index
for _ in 0..100 {
let reader = index.reader()?;
let searcher = reader.searcher();
- let colletor = CustomScoreTopCollector::new(2, true, |_: &SegmentReader| {
- // Score is doc_id * 10
- |doc_id: DocId| doc_id * 10
- });
+ let colletor =
+ CustomScoreTopCollector::<_, Descending, _, _>::new(2, true, |_: &SegmentReader| {
+ |doc_id: DocId| u64::from(doc_id * 10)
+ });
let result = searcher.search(&AllQuery, &colletor)?;
assert_eq!(2, result.items.len());
// So we expect that the highest score is 990
- assert_eq!(result.items[0].score, 990);
- assert_eq!(result.items[1].score, 980);
+ assert_eq!(result.items[0].0, 990);
+ assert_eq!(result.items[1].0, 980);
Ok(())
}
Created tique/src/conditional_collector/mod.rs
+mod custom_score;
+mod top_collector;
+mod topk;
+
+pub mod traits;
+
+pub use custom_score::CustomScoreTopCollector;
+pub use top_collector::{CollectionResult, TopCollector};
+pub use topk::{Ascending, Descending};
Created tique/src/conditional_collector/top_collector.rs
+use std::marker::PhantomData;
+
+use tantivy::{
+ collector::{Collector, SegmentCollector},
+ DocAddress, DocId, Result, Score, SegmentLocalId, SegmentReader,
+};
+
+use super::{
+ topk::{TopK, TopKProvider},
+ traits::{CheckCondition, ConditionForSegment},
+ CustomScoreTopCollector,
+};
+
+pub struct TopCollector<T, P, CF> {
+ limit: usize,
+ condition_for_segment: CF,
+ _score: PhantomData<T>,
+ _provider: PhantomData<P>,
+}
+
+impl<T, P, CF> TopCollector<T, P, CF>
+where
+ T: PartialOrd,
+ P: TopKProvider<T>,
+ CF: ConditionForSegment<T>,
+{
+ pub fn new(limit: usize, condition_for_segment: CF) -> Self {
+ if limit < 1 {
+ panic!("Limit must be greater than 0");
+ }
+ TopCollector {
+ limit,
+ condition_for_segment,
+ _score: PhantomData,
+ _provider: PhantomData,
+ }
+ }
+}
+
+macro_rules! impl_top_fast_field {
+ ($type: ident, $err: literal) => {
+ impl<P, CF> TopCollector<$type, P, CF>
+ where
+ P: 'static + Send + Sync + TopKProvider<$type>,
+ CF: Send + Sync + ConditionForSegment<$type>,
+ {
+ pub fn top_fast_field(
+ self,
+ field: tantivy::schema::Field,
+ ) -> impl Collector<Fruit = CollectionResult<$type>> {
+ let scorer_for_segment = move |reader: &SegmentReader| {
+ let ff = reader.fast_fields().$type(field).expect($err);
+ move |doc_id| ff.get(doc_id)
+ };
+ CustomScoreTopCollector::<$type, P, _, _>::new(
+ self.limit,
+ self.condition_for_segment,
+ scorer_for_segment,
+ )
+ }
+ }
+ };
+}
+
+impl_top_fast_field!(u64, "Field is not a fast u64 field");
+impl_top_fast_field!(i64, "Field is not a fast i64 field");
+impl_top_fast_field!(f64, "Field is not a fast f64 field");
+
+impl<P, CF> Collector for TopCollector<Score, P, CF>
+where
+ P: 'static + Send + Sync + TopKProvider<Score>,
+ CF: Sync + ConditionForSegment<Score>,
+{
+ type Fruit = CollectionResult<Score>;
+ type Child = TopSegmentCollector<Score, P::Child, CF::Type>;
+
+ fn requires_scoring(&self) -> bool {
+ true
+ }
+
+ fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
+ Ok(P::merge_many(self.limit, children))
+ }
+
+ fn for_segment(
+ &self,
+ segment_id: SegmentLocalId,
+ reader: &SegmentReader,
+ ) -> Result<Self::Child> {
+ Ok(TopSegmentCollector::new(
+ segment_id,
+ P::new_topk(self.limit),
+ self.condition_for_segment.for_segment(reader),
+ ))
+ }
+}
+
+pub struct TopSegmentCollector<T, K, C> {
+ total: usize,
+ visited: usize,
+ segment_id: SegmentLocalId,
+ topk: K,
+ condition: C,
+ _marker: PhantomData<T>,
+}
+
+impl<T, K, C> TopSegmentCollector<T, K, C>
+where
+ T: Copy,
+ K: TopK<T, DocId>,
+ C: CheckCondition<T>,
+{
+ pub fn new(segment_id: SegmentLocalId, topk: K, condition: C) -> Self {
+ Self {
+ total: 0,
+ visited: 0,
+ segment_id,
+ topk,
+ condition,
+ _marker: PhantomData,
+ }
+ }
+
+ #[cfg(test)]
+ fn into_topk(self) -> K {
+ self.topk
+ }
+
+ pub fn collect(&mut self, doc: DocId, score: T) {
+ self.total += 1;
+ if self
+ .condition
+ .check(self.segment_id, doc, score, K::ASCENDING)
+ {
+ self.visited += 1;
+ self.topk.visit(score, doc);
+ }
+ }
+
+ pub fn into_collection_result(self) -> CollectionResult<T> {
+ let segment_id = self.segment_id;
+ let items = self
+ .topk
+ .into_vec()
+ .into_iter()
+ .map(|(score, doc)| (score, DocAddress(segment_id, doc)))
+ .collect();
+
+ // XXX This is unsorted. It's ok because we sort during
+        // merge, but using the same type to mean two things is
+ // rather confusing
+ CollectionResult {
+ total: self.total,
+ visited: self.visited,
+ items,
+ }
+ }
+}
+
+impl<K, C> SegmentCollector for TopSegmentCollector<Score, K, C>
+where
+ K: 'static + TopK<Score, DocId>,
+ C: CheckCondition<Score>,
+{
+ type Fruit = CollectionResult<Score>;
+
+ fn collect(&mut self, doc: DocId, score: Score) {
+ TopSegmentCollector::collect(self, doc, score)
+ }
+
+ fn harvest(self) -> Self::Fruit {
+ TopSegmentCollector::into_collection_result(self)
+ }
+}
+
+#[derive(Debug)]
+pub struct CollectionResult<T> {
+ pub total: usize,
+ pub visited: usize,
+ pub items: Vec<(T, DocAddress)>,
+}
+
+#[cfg(test)]
+mod tests {
+
+ use super::*;
+ use crate::conditional_collector::{
+ topk::{AscendingTopK, DescendingTopK},
+ Ascending, Descending,
+ };
+
+ use tantivy::{
+ query::{AllQuery, TermQuery},
+ schema, Document, Index, Result, Term,
+ };
+
+ #[test]
+ fn condition_is_checked() {
+ const LIMIT: usize = 4;
+
+ let mut nil_collector = TopSegmentCollector::new(0, AscendingTopK::new(LIMIT), false);
+
+ let mut top_collector = TopSegmentCollector::new(0, AscendingTopK::new(LIMIT), true);
+
+ let condition = |_sid, doc, _score, _asc| doc % 2 == 1;
+
+ let mut just_odds = TopSegmentCollector::new(0, AscendingTopK::new(LIMIT), condition);
+
+ for i in 0..4 {
+ nil_collector.collect(i, 420.0);
+ top_collector.collect(i, 420.0);
+ just_odds.collect(i, 420.0);
+ }
+
+ assert_eq!(0, nil_collector.harvest().items.len());
+ assert_eq!(4, top_collector.harvest().items.len());
+
+ // Verify that the collected items respect the condition
+ let result = just_odds.harvest();
+ assert_eq!(4, result.total);
+ assert_eq!(2, result.items.len());
+ for (score, doc) in result.items {
+ let DocAddress(seg_id, doc_id) = doc;
+ assert!(condition(seg_id, doc_id, score, true))
+ }
+ }
+
+ fn check_segment_collector<K, C>(
+ topk: K,
+ condition: C,
+ input: Vec<(Score, DocId)>,
+ wanted: Vec<(Score, DocId)>,
+ ) where
+ K: TopK<Score, DocId> + 'static,
+ C: CheckCondition<Score>,
+ {
+ let mut collector = TopSegmentCollector::new(0, topk, condition);
+
+ for (score, id) in input {
+ collector.collect(id, score);
+ }
+
+ assert_eq!(wanted, collector.into_topk().into_sorted_vec());
+ }
+
+ #[test]
+ fn collection_with_a_marker_smoke() {
+ // XXX property test maybe? Essentially we are creating
+ // a Vec<(Score, DocId)> sorted as `Scored` would,
+ // then we pick an arbitrary position to pivot and
+ // expect the DescendingTopK to pick everything below
+ // and the AscendingTopK to pick everything above
+ let marker = (0.5, DocAddress(0, 4));
+
+ check_segment_collector(
+ DescendingTopK::new(10),
+ marker,
+ vec![
+ // Every doc with a higher score has appeared already
+ (0.6, 7),
+ (0.7, 5),
+ // Docs with the same score, but lower id too
+ (0.5, 3),
+ (0.5, 2),
+ // [pivot] And, of course, the same doc should not be collected
+ (0.5, 4),
+ // Lower scores are in
+ (0.0, 1),
+ // Same score but higher doc, too
+ (0.5, 6),
+ ],
+ vec![(0.5, 6), (0.0, 1)],
+ );
+
+ check_segment_collector(
+ AscendingTopK::new(10),
+ marker,
+ vec![
+ // Every doc with a higher score should be picked
+ (0.6, 7),
+ (0.7, 5),
+ // Same score but lower id as well
+ (0.5, 3),
+ (0.5, 2),
+ // [pivot] The same doc should not be collected
+ (0.5, 4),
+ // Docs with lower scores are discarded
+ (0.0, 1),
+                // Same score but higher doc is discarded too
+ (0.5, 6),
+ ],
+ vec![(0.5, 2), (0.5, 3), (0.6, 7), (0.7, 5)],
+ );
+ }
+
+ #[test]
+ fn collection_ordering_integration() -> Result<()> {
+ let mut builder = schema::SchemaBuilder::new();
+
+ let text_field = builder.add_text_field("text", schema::TEXT);
+
+ let index = Index::create_in_ram(builder.build());
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
+
+ let add_doc = |text: &str| {
+ let mut doc = Document::new();
+ doc.add_text(text_field, text);
+ writer.add_document(doc);
+ };
+
+ const NUM_DOCS: usize = 3;
+ add_doc("the first doc is simple");
+ add_doc("the second doc is a bit larger");
+ add_doc("and the third document is rubbish");
+
+ writer.commit()?;
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ let collector_asc = TopCollector::<_, Ascending, _>::new(NUM_DOCS, true);
+ let collector_desc = TopCollector::<_, Descending, _>::new(NUM_DOCS, true);
+
+ // Query for "the", which matches all docs and yields
+ // a distinct score for each
+ let query = TermQuery::new(
+ Term::from_field_text(text_field, "the"),
+ schema::IndexRecordOption::WithFreqsAndPositions,
+ );
+ let (asc, desc) = searcher.search(&query, &(collector_asc, collector_desc))?;
+
+ assert_eq!(NUM_DOCS, asc.items.len());
+ assert_eq!(NUM_DOCS, desc.items.len());
+
+ let asc_scores = asc
+ .items
+ .iter()
+ .map(|(score, _doc)| score)
+ .collect::<Vec<_>>();
+
+ let mut prev = None;
+ for score in &asc_scores {
+ if let Some(previous) = prev {
+ assert!(previous < score, "The scores should be ascending");
+ }
+ prev = Some(score)
+ }
+
+ let mut desc_scores = desc
+ .items
+ .iter()
+ .map(|(score, _doc)| score)
+ .collect::<Vec<_>>();
+
+ desc_scores.reverse();
+ assert_eq!(asc_scores, desc_scores);
+
+ Ok(())
+ }
+
+ #[test]
+ fn fast_field_collection() -> Result<()> {
+ let mut builder = schema::SchemaBuilder::new();
+
+ let field = builder.add_f64_field("field", schema::FAST);
+
+ let index = Index::create_in_ram(builder.build());
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
+
+ const NUM_DOCS: usize = 100;
+ for v in 0..NUM_DOCS {
+ let mut doc = Document::new();
+ doc.add_f64(field, f64::from(v as u32));
+ writer.add_document(doc);
+ }
+
+ writer.commit()?;
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ let collector_asc =
+ TopCollector::<f64, Ascending, _>::new(NUM_DOCS, true).top_fast_field(field);
+ let collector_desc =
+ TopCollector::<f64, Descending, _>::new(NUM_DOCS, true).top_fast_field(field);
+
+ let (top_asc, mut top_desc) =
+ searcher.search(&AllQuery, &(collector_asc, collector_desc))?;
+
+ assert_eq!(NUM_DOCS, top_asc.items.len());
+
+ top_desc.items.reverse();
+ assert_eq!(top_asc.items, top_desc.items);
+
+ Ok(())
+ }
+}
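A brief sketch of the `top_fast_field` path above, mirroring how cantine's `collect!` macro drives it (the `calories` field is illustrative; not part of the commit):

```rust
use tantivy::{query::AllQuery, schema::Field, Result, Searcher};
use tique::conditional_collector::{Ascending, TopCollector};

// Lowest fast-field value first, keyed by a FAST u64 field.
fn lowest_calories(searcher: &Searcher, calories: Field) -> Result<()> {
    let collector = TopCollector::<u64, Ascending, _>::new(20, true).top_fast_field(calories);
    let result = searcher.search(&AllQuery, &collector)?;
    // Items are (fast field value, DocAddress) pairs, smallest value first.
    for (value, addr) in &result.items {
        println!("{:?} => {}", addr, value);
    }
    Ok(())
}
```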
Deleted tique/src/top_collector/mod.rs
-mod conditional_collector;
-mod custom_score;
-mod topk;
-mod tweaked_score;
-
-pub mod fastfield;
-
-pub use conditional_collector::{
- CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
- ConditionalTopSegmentCollector, SearchMarker,
-};
-pub use custom_score::{CustomScoreTopCollector, DocScorer, ScorerForSegment};
-pub use topk::{Scored, TopK};
-pub use tweaked_score::{ModifierForSegment, ScoreModifier, TweakedScoreTopCollector};
Deleted tique/src/top_collector/topk.rs
-use std::{
- cmp::{Ordering, Reverse},
- collections::BinaryHeap,
-};
-
-pub struct TopK<S, D> {
- limit: usize,
- heap: BinaryHeap<Reverse<Scored<S, D>>>,
-}
-
-impl<S: PartialOrd, D: PartialOrd> TopK<S, D> {
- pub fn new(limit: usize) -> Self {
- Self {
- limit,
- heap: BinaryHeap::with_capacity(limit),
- }
- }
-
- pub fn len(&self) -> usize {
- self.heap.len()
- }
-
- pub fn is_empty(&self) -> bool {
- self.heap.is_empty()
- }
-
- pub fn visit(&mut self, score: S, doc: D) {
- if self.heap.len() < self.limit {
- self.heap.push(Reverse(Scored { score, doc }));
- } else if let Some(mut head) = self.heap.peek_mut() {
- if match head.0.score.partial_cmp(&score) {
- Some(Ordering::Equal) => doc < head.0.doc,
- Some(Ordering::Less) => true,
- _ => false,
- } {
- head.0.score = score;
- head.0.doc = doc;
- }
- }
- }
-
- pub fn into_sorted_vec(self) -> Vec<Scored<S, D>> {
- self.heap
- .into_sorted_vec()
- .into_iter()
- .map(|Reverse(item)| item)
- .collect()
- }
-
- pub fn into_vec(self) -> Vec<Scored<S, D>> {
- self.heap
- .into_vec()
- .into_iter()
- .map(|Reverse(item)| item)
- .collect()
- }
-}
-
-// TODO warn about exposing docid,segment_id publicly
-#[derive(Debug, Clone)]
-pub struct Scored<S, D> {
- pub score: S,
- pub doc: D,
-}
-
-impl<S: PartialOrd, D: PartialOrd> Scored<S, D> {
- pub fn new(score: S, doc: D) -> Self {
- Self { score, doc }
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> PartialOrd for Scored<S, D> {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- Some(self.cmp(other))
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> Ord for Scored<S, D> {
- #[inline]
- fn cmp(&self, other: &Self) -> Ordering {
- // Highest score first
- match self.score.partial_cmp(&other.score) {
- Some(Ordering::Equal) | None => {
- // Break even by lowest id
- other.doc.partial_cmp(&self.doc).unwrap_or(Ordering::Equal)
- }
- Some(rest) => rest,
- }
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> PartialEq for Scored<S, D> {
- fn eq(&self, other: &Self) -> bool {
- self.cmp(other) == Ordering::Equal
- }
-}
-
-impl<S: PartialOrd, D: PartialOrd> Eq for Scored<S, D> {}
-
-#[cfg(test)]
-mod tests {
- use super::{Scored, TopK};
-
- #[test]
- fn not_at_capacity() {
- let mut topk = TopK::new(3);
-
- assert!(topk.is_empty());
-
- topk.visit(0.8, 1);
- topk.visit(0.2, 3);
- topk.visit(0.3, 5);
-
- assert_eq!(3, topk.len());
-
- assert_eq!(
- vec![
- Scored::new(0.8, 1),
- Scored::new(0.3, 5),
- Scored::new(0.2, 3)
- ],
- topk.into_sorted_vec()
- )
- }
-
- #[test]
- fn at_capacity() {
- let mut topk = TopK::new(4);
-
- topk.visit(0.8, 1);
- topk.visit(0.2, 3);
- topk.visit(0.3, 5);
- topk.visit(0.9, 7);
- topk.visit(-0.2, 9);
-
- assert_eq!(4, topk.len());
-
- assert_eq!(
- vec![
- Scored::new(0.9, 7),
- Scored::new(0.8, 1),
- Scored::new(0.3, 5),
- Scored::new(0.2, 3)
- ],
- topk.into_sorted_vec()
- );
- }
-
- #[test]
- fn break_even_scores_by_lowest_doc() {
- let mut topk = TopK::new(5);
- topk.visit(0.1, 3);
- topk.visit(0.1, 1);
- topk.visit(0.1, 6);
- topk.visit(0.5, 5);
- topk.visit(0.5, 4);
- topk.visit(0.1, 2);
- assert_eq!(
- vec![
- Scored::new(0.5, 4),
- Scored::new(0.5, 5),
- Scored::new(0.1, 1),
- Scored::new(0.1, 2),
- Scored::new(0.1, 3),
- ],
- topk.into_sorted_vec()
- );
- }
-}
Deleted tique/src/top_collector/tweaked_score.rs
-use tantivy::{
- collector::{Collector, SegmentCollector},
- DocId, Result, Score, SegmentLocalId, SegmentReader,
-};
-
-use super::{
- CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
- ConditionalTopSegmentCollector,
-};
-
-pub trait ScoreModifier<T>: 'static {
- fn modify(&self, doc_id: DocId, score: Score) -> T;
-}
-
-impl<F, T> ScoreModifier<T> for F
-where
- F: 'static + Fn(DocId, Score) -> T,
-{
- fn modify(&self, doc_id: DocId, score: Score) -> T {
- (self)(doc_id, score)
- }
-}
-
-pub trait ModifierForSegment<T>: Sync {
- type Type: ScoreModifier<T>;
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type;
-}
-
-impl<T, C, F> ModifierForSegment<T> for F
-where
- F: 'static + Sync + Send + Fn(&SegmentReader) -> C,
- C: ScoreModifier<T>,
-{
- type Type = C;
-
- fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
- (self)(reader)
- }
-}
-
-pub struct TweakedScoreTopSegmentCollector<T, C, F>
-where
- C: CheckCondition<T>,
-{
- modifier: F,
- collector: ConditionalTopSegmentCollector<T, C>,
-}
-
-impl<T, C, F> TweakedScoreTopSegmentCollector<T, C, F>
-where
- T: PartialOrd + Copy,
- C: CheckCondition<T>,
- F: ScoreModifier<T>,
-{
- fn new(segment_id: SegmentLocalId, limit: usize, condition: C, modifier: F) -> Self {
- Self {
- modifier,
- collector: ConditionalTopSegmentCollector::new(segment_id, limit, condition),
- }
- }
-}
-
-impl<T, C, F> SegmentCollector for TweakedScoreTopSegmentCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: CheckCondition<T>,
- F: ScoreModifier<T>,
-{
- type Fruit = CollectionResult<T>;
-
- fn collect(&mut self, doc: DocId, score: Score) {
- let score = self.modifier.modify(doc, score);
- self.collector.visit(doc, score);
- }
-
- fn harvest(self) -> Self::Fruit {
- self.collector.into_collection_result()
- }
-}
-
-pub struct TweakedScoreTopCollector<T, C, F>
-where
- C: ConditionForSegment<T>,
-{
- modifier_factory: F,
- condition_factory: C,
- collector: ConditionalTopCollector<T, C>,
-}
-
-impl<T, C, F> TweakedScoreTopCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: ConditionForSegment<T>,
- F: 'static + Sync + ModifierForSegment<T>,
-{
- pub fn new(limit: usize, condition_factory: C, modifier_factory: F) -> Self {
- Self {
- collector: ConditionalTopCollector::with_limit(limit, condition_factory.clone()),
- modifier_factory,
- condition_factory,
- }
- }
-}
-
-impl<T, C, F> Collector for TweakedScoreTopCollector<T, C, F>
-where
- T: 'static + PartialOrd + Copy + Sync + Send,
- C: ConditionForSegment<T> + Sync,
- F: 'static + ModifierForSegment<T>,
-{
- type Fruit = CollectionResult<T>;
- type Child = TweakedScoreTopSegmentCollector<T, C::Type, F::Type>;
-
- fn requires_scoring(&self) -> bool {
- true
- }
-
- fn merge_fruits(&self, children: Vec<Self::Fruit>) -> Result<Self::Fruit> {
- Ok(self.collector.merge_many(children))
- }
-
- fn for_segment(
- &self,
- segment_id: SegmentLocalId,
- reader: &SegmentReader,
- ) -> Result<Self::Child> {
- let modifier = self.modifier_factory.for_segment(reader);
- Ok(TweakedScoreTopSegmentCollector::new(
- segment_id,
- self.collector.limit,
- self.condition_factory.for_segment(reader),
- modifier,
- ))
- }
-}
-
-#[cfg(test)]
-mod tests {
-
- use super::*;
-
- use tantivy::{query::AllQuery, schema::SchemaBuilder, Document, Index};
-
- #[test]
- fn integration() -> Result<()> {
- let builder = SchemaBuilder::new();
- let index = Index::create_in_ram(builder.build());
-
- let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
-
- for _ in 0..100 {
- writer.add_document(Document::new());
- }
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
-
- let colletor = TweakedScoreTopCollector::new(100, true, |_: &SegmentReader| {
- |doc_id: DocId, score: Score| f64::from(score) * f64::from(doc_id)
- });
-
- let result = searcher.search(&AllQuery, &colletor)?;
-
- assert_eq!(100, result.items.len());
- let mut item_iter = result.items.into_iter();
- let mut last_score = item_iter.next().unwrap().score;
-
- // An AllQuery ends up with every doc scoring the same, so
- // this means highest ids will come first
- for item in item_iter {
- assert!(last_score > item.score);
- last_score = item.score;
- }
-
- Ok(())
- }
-}