Add support for (reversible) fast field scoring
This patch replaces the original `ordered_by_XXX_fast_field` with a more explicit `fastfield::{ascending,descending}` idiom. Essentially, this is a very verbose way of encoding a runtime boolean check into the type system.
- Id
- 1c1617a246ceeef6853aef672fa7de58898e07c6
- Author
- Caio
- Commit time
- 2020-01-09T19:10:53+01:00
Modified tique/src/top_collector/mod.rs
mod conditional_collector;
mod custom_score;
-mod field;
mod topk;
+
+pub mod fastfield;
pub use conditional_collector::{
CheckCondition, CollectionResult, ConditionForSegment, ConditionalTopCollector,
ConditionalTopSegmentCollector, SearchMarker,
};
-pub use custom_score::CustomScoreTopCollector;
-pub use field::{ordered_by_f64_fast_field, ordered_by_i64_fast_field, ordered_by_u64_fast_field};
+pub use custom_score::{CustomScoreTopCollector, DocScorer, ScorerForSegment};
pub use topk::{Scored, TopK};
Created tique/src/top_collector/fastfield.rs
+use std::{marker::PhantomData, ops::Neg};
+
+use tantivy::{
+ fastfield::{FastFieldReader, FastValue},
+ schema::Field,
+ DocId, SegmentReader,
+};
+
+use super::{DocScorer, ScorerForSegment};
+
+pub struct PlainDocScorer<T: FastValue>(FastFieldReader<T>);
+pub struct ReversedDocScorer<T: FastValue>(FastFieldReader<T>);
+
+impl<T> DocScorer<T> for PlainDocScorer<T>
+where
+ T: FastValue + 'static,
+{
+ fn score(&self, doc_id: DocId) -> T {
+ self.0.get(doc_id)
+ }
+}
+
+impl DocScorer<u64> for ReversedDocScorer<u64> {
+ fn score(&self, doc_id: DocId) -> u64 {
+ std::u64::MAX - self.0.get(doc_id)
+ }
+}
+
+macro_rules! impl_neg_reversed_scorer {
+ ($type: ty) => {
+ impl DocScorer<$type> for ReversedDocScorer<$type> {
+ fn score(&self, doc_id: DocId) -> $type {
+ self.0.get(doc_id).neg()
+ }
+ }
+ };
+}
+
+impl_neg_reversed_scorer!(i64);
+impl_neg_reversed_scorer!(f64);
+
+pub struct DescendingFastFieldScorer<T>(Field, PhantomData<T>);
+pub struct AscendingFastFieldScorer<T>(Field, PhantomData<T>);
+
+pub fn descending<T>(field: Field) -> DescendingFastFieldScorer<T> {
+ DescendingFastFieldScorer(field, PhantomData)
+}
+
+pub fn ascending<T>(field: Field) -> AscendingFastFieldScorer<T> {
+ AscendingFastFieldScorer(field, PhantomData)
+}
+
+macro_rules! impl_scorer_for_segment {
+ ($type: ident) => {
+ impl ScorerForSegment<$type> for DescendingFastFieldScorer<$type> {
+ type Type = PlainDocScorer<$type>;
+
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
+ let scorer = reader.fast_fields().$type(self.0).expect("Field is FAST");
+ PlainDocScorer(scorer)
+ }
+ }
+
+ impl ScorerForSegment<$type> for AscendingFastFieldScorer<$type> {
+ type Type = ReversedDocScorer<$type>;
+
+ fn for_segment(&self, reader: &SegmentReader) -> Self::Type {
+ let scorer = reader.fast_fields().$type(self.0).expect("Field is FAST");
+ ReversedDocScorer(scorer)
+ }
+ }
+ };
+}
+
+impl_scorer_for_segment!(f64);
+impl_scorer_for_segment!(i64);
+impl_scorer_for_segment!(u64);
+
+#[cfg(test)]
+use super::{CollectionResult, CustomScoreTopCollector};
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use tantivy::{
+ query::AllQuery,
+ schema::{SchemaBuilder, FAST},
+ DocAddress, Document, Index, Result,
+ };
+
+ fn just_the_ids<T: PartialOrd>(res: CollectionResult<T>) -> Vec<DocId> {
+ res.items
+ .into_iter()
+ .map(|item| {
+ let DocAddress(_segment, id) = item.doc;
+ id
+ })
+ .collect()
+ }
+
+ macro_rules! check_order_from_sorted_values {
+ ($name: ident, $field: ident, $add: ident, $type: ty, $values: expr) => {
+ #[test]
+ fn $name() -> Result<()> {
+ let mut sb = SchemaBuilder::new();
+
+ let field = sb.$field("field", FAST);
+ let index = Index::create_in_ram(sb.build());
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
+
+ for v in $values.into_iter() {
+ let mut doc = Document::new();
+ doc.$add(field, *v);
+ writer.add_document(doc);
+ }
+
+ writer.commit()?;
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+ let size = $values.len();
+ let condition = true;
+
+ let collector =
+ CustomScoreTopCollector::new(size, condition, descending::<$type>(field));
+
+ let reversed_collector =
+ CustomScoreTopCollector::new(size, condition, ascending::<$type>(field));
+
+ let (top, reversed_top) =
+ searcher.search(&AllQuery, &(collector, reversed_collector))?;
+
+ let sorted_scores: Vec<$type> = top.items.iter().map(|r| r.score).collect();
+ assert_eq!(
+ $values,
+ sorted_scores.as_slice(),
+ "found scores don't match input"
+ );
+
+ let ids = just_the_ids(top);
+ let mut reversed_ids = just_the_ids(reversed_top);
+
+ reversed_ids.reverse();
+ assert_eq!(
+ ids,
+ reversed_ids.as_slice(),
+ "should have found the same ids, in reversed order"
+ );
+
+ Ok(())
+ }
+ };
+ }
+
+ check_order_from_sorted_values!(
+ u64_field_sort_functionality,
+ add_u64_field,
+ add_u64,
+ u64,
+ [3, 2, 1, 0]
+ );
+
+ check_order_from_sorted_values!(
+ i64_field_sort_functionality,
+ add_i64_field,
+ add_i64,
+ i64,
+ [3, 2, 1, 0, -1, -2, -3]
+ );
+
+ check_order_from_sorted_values!(
+ f64_field_sort_functionality,
+ add_f64_field,
+ add_f64,
+ f64,
+ [100.0, 42.0, 0.71, 0.42]
+ );
+}
Deleted tique/src/top_collector/field.rs
-use tantivy::{collector::Collector, schema::Field, SegmentReader};
-
-use super::{CollectionResult, ConditionForSegment, CustomScoreTopCollector};
-
-macro_rules! fast_field_custom_score_collector {
- ($name: ident, $type: ty, $reader: ident) => {
- pub fn $name<C>(
- field: Field,
- limit: usize,
- condition_factory: C,
- ) -> impl Collector<Fruit = CollectionResult<$type>>
- where
- C: ConditionForSegment<$type> + Sync,
- {
- let scorer_for_segment = move |reader: &SegmentReader| {
- let scorer = reader
- .fast_fields()
- .$reader(field)
- .expect("Not a fast field");
- move |doc_id| scorer.get(doc_id)
- };
- CustomScoreTopCollector::new(limit, condition_factory, scorer_for_segment)
- }
- };
-}
-
-fast_field_custom_score_collector!(ordered_by_i64_fast_field, i64, i64);
-fast_field_custom_score_collector!(ordered_by_u64_fast_field, u64, u64);
-fast_field_custom_score_collector!(ordered_by_f64_fast_field, f64, f64);
-
-#[cfg(test)]
-mod tests {
- use super::*;
-
- use tantivy::{
- query::AllQuery,
- schema::{SchemaBuilder, FAST},
- Document, Index, Result,
- };
-
- #[test]
- fn integration() -> Result<()> {
- let mut sb = SchemaBuilder::new();
-
- let u64_field = sb.add_u64_field("u64", FAST);
- let i64_field = sb.add_i64_field("i64", FAST);
- let f64_field = sb.add_f64_field("f64", FAST);
-
- let index = Index::create_in_ram(sb.build());
- let mut writer = index.writer_with_num_threads(1, 50_000_000)?;
-
- let add_doc = |a: u64, b: i64, c: f64| {
- let mut doc = Document::new();
- doc.add_u64(u64_field, a);
- doc.add_i64(i64_field, b);
- doc.add_f64(f64_field, c);
- writer.add_document(doc);
- };
-
- add_doc(10, -10, 7.2);
- add_doc(20, -20, 4.2);
-
- writer.commit()?;
-
- let reader = index.reader()?;
- let searcher = reader.searcher();
-
- let u64_collector = ordered_by_u64_fast_field(u64_field, 2, true);
- let i64_collector = ordered_by_i64_fast_field(i64_field, 2, true);
- let f64_collector = ordered_by_f64_fast_field(f64_field, 2, true);
-
- let (top_u64, top_i64, top_f64) =
- searcher.search(&AllQuery, &(u64_collector, i64_collector, f64_collector))?;
-
- let sorted_u64_scores: Vec<u64> = top_u64.items.into_iter().map(|r| r.score).collect();
- assert_eq!(vec![20, 10], sorted_u64_scores);
-
- let sorted_i64_scores: Vec<i64> = top_i64.items.into_iter().map(|r| r.score).collect();
- assert_eq!(vec![-10, -20], sorted_i64_scores);
-
- let sorted_f64_scores: Vec<f64> = top_f64.items.into_iter().map(|r| r.score).collect();
- assert_eq!(vec![7.2, 4.2], sorted_f64_scores);
-
- Ok(())
- }
-}