Add `top_fast_field` helper method
This patch adds a method to `TopCollector` to convert it to a `CustomScoreTopCollector` sorted by a fast field. We can drive the whole thing via the basic top collector, which leads to something kinda cute: `TopCollector::<f64, Ascending, _>::new(10, condition).top_fast_field(field);`
- Id
- c62f1b132e52721cb152e4fa13df0e21bce59c20
- Author
- Caio
- Commit time
- 2020-01-21T17:33:27+01:00
Modified tique/src/conditional_collector/top_collector.rs
use super::{
topk::{TopK, TopKProvider},
traits::{CheckCondition, ConditionForSegment},
+ CustomScoreTopCollector,
};
pub struct TopCollector<T, P, CF> {
}
}
}
+
+macro_rules! impl_top_fast_field {
+ ($type: ident, $err: literal) => {
+ impl<P, CF> TopCollector<$type, P, CF>
+ where
+ P: 'static + Send + Sync + TopKProvider<$type>,
+ CF: Send + Sync + ConditionForSegment<$type>,
+ {
+ pub fn top_fast_field(
+ self,
+ field: tantivy::schema::Field,
+ ) -> impl Collector<Fruit = CollectionResult<$type>> {
+ let scorer_for_segment = move |reader: &SegmentReader| {
+ let ff = reader.fast_fields().$type(field).expect($err);
+ move |doc_id| ff.get(doc_id)
+ };
+ CustomScoreTopCollector::<$type, P, _, _>::new(
+ self.limit,
+ self.condition_for_segment,
+ scorer_for_segment,
+ )
+ }
+ }
+ };
+}
+
+impl_top_fast_field!(u64, "Field is not a fast u64 field");
+impl_top_fast_field!(i64, "Field is not a fast i64 field");
+impl_top_fast_field!(f64, "Field is not a fast f64 field");
impl<P, CF> Collector for TopCollector<Score, P, CF>
where
Ascending, Descending,
};
- use tantivy::{query::TermQuery, schema, Document, Index, Result, Term};
+ use tantivy::{
+ query::{AllQuery, TermQuery},
+ schema, Document, Index, Result, Term,
+ };
#[test]
fn condition_is_checked() {
desc_scores.reverse();
assert_eq!(asc_scores, desc_scores);
+
+ Ok(())
+ }
+
+ #[test]
+ fn fast_field_collection() -> Result<()> {
+ let mut builder = schema::SchemaBuilder::new();
+
+ let field = builder.add_f64_field("field", schema::FAST);
+
+ let index = Index::create_in_ram(builder.build());
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
+
+ const NUM_DOCS: usize = 100;
+ for v in 0..NUM_DOCS {
+ let mut doc = Document::new();
+ doc.add_f64(field, f64::from(v as u32));
+ writer.add_document(doc);
+ }
+
+ writer.commit()?;
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ let collector_asc =
+ TopCollector::<f64, Ascending, _>::new(NUM_DOCS, true).top_fast_field(field);
+ let collector_desc =
+ TopCollector::<f64, Descending, _>::new(NUM_DOCS, true).top_fast_field(field);
+
+ let (top_asc, mut top_desc) =
+ searcher.search(&AllQuery, &(collector_asc, collector_desc))?;
+
+ assert_eq!(NUM_DOCS, top_asc.items.len());
+
+ top_desc.items.reverse();
+ assert_eq!(top_asc.items, top_desc.items);
Ok(())
}