caio.co/de/cantine

Use a simpler Feature trait

Flattening it all solves the "how to expose the generated structs"
problem and simplifies the code by a *lot*.

Besides, the previous incantation with support for multiple query
types was cute and all, but useless in practice: a custom query means
implementing business logic, and the main purpose of this
thing is avoiding that.
Id
9377c12b46f727706ae97a90a36acb2956d86b2e
Author
Caio
Commit time
2020-02-04T20:58:28+01:00

Modified cantine_derive/src/lib.rs

@@ -57,14 +57,15 @@
}
}

-pub trait Mergeable: Send + Sync {
+pub trait Aggregator<Q, F>: Send + Sync {
fn merge_same_size(&mut self, other: &Self);
+ fn collect(&mut self, query: &Q, feature: &F);
+ fn from_query(query: &Q) -> Self;
}

-pub trait Feature<TQuery>: Sync {
- type Agg: Mergeable + for<'a> From<&'a TQuery>;
-
- fn collect_into(&self, query: &TQuery, agg: &mut Self::Agg);
+pub trait Feature: Sized + Sync {
+ type Query: Sync + Clone;
+ type Agg: Aggregator<Self::Query, Self>;
}

pub trait FeatureForSegment<T>: Sync {
@@ -90,11 +91,10 @@
_marker: PhantomData<T>,
}

-impl<T, A, Q, F, O> FeatureCollector<T, Q, F>
+impl<T, Q, F, O> FeatureCollector<T, Q, F>
where
- T: 'static + Feature<Q, Agg = A>,
+ T: 'static + Feature,
Q: 'static + Clone + Sync,
- A: 'static + Mergeable + for<'a> From<&'a Q>,
F: FeatureForSegment<T, Output = O>,
O: 'static + FeatureForDoc<T>,
{
@@ -107,16 +107,14 @@
}
}

-impl<T, A, Q, F, O> Collector for FeatureCollector<T, Q, F>
+impl<T, F, O> Collector for FeatureCollector<T, T::Query, F>
where
- T: 'static + Feature<Q, Agg = A>,
- Q: 'static + Clone + Sync,
- A: 'static + Mergeable + for<'a> From<&'a Q>,
+ T: 'static + Feature,
F: FeatureForSegment<T, Output = O>,
O: 'static + FeatureForDoc<T>,
{
- type Fruit = A;
- type Child = FeatureSegmentCollector<T, A, Q, O>;
+ type Fruit = T::Agg;
+ type Child = FeatureSegmentCollector<T, O>;

fn for_segment(
&self,
@@ -124,10 +122,9 @@
segment_reader: &SegmentReader,
) -> Result<Self::Child> {
Ok(FeatureSegmentCollector {
- agg: A::from(&self.query),
+ agg: T::Agg::from_query(&self.query),
query: self.query.clone(),
reader: self.reader_factory.for_segment(segment_reader),
- _marker: PhantomData,
})
}

@@ -138,7 +135,9 @@
fn merge_fruits(&self, fruits: Vec<Self::Fruit>) -> Result<Self::Fruit> {
let mut iter = fruits.into_iter();

- let mut first = iter.next().unwrap_or_else(|| A::from(&self.query));
+ let mut first = iter
+ .next()
+ .unwrap_or_else(|| T::Agg::from_query(&self.query));

for fruit in iter {
first.merge_same_size(&fruit);
@@ -161,25 +160,22 @@
}
}

-pub struct FeatureSegmentCollector<T, A, Q, F> {
- agg: A,
- query: Q,
+pub struct FeatureSegmentCollector<T: Feature, F> {
+ agg: T::Agg,
+ query: T::Query,
reader: F,
- _marker: PhantomData<T>,
}

-impl<T, A, Q, F> SegmentCollector for FeatureSegmentCollector<T, A, Q, F>
+impl<T, F> SegmentCollector for FeatureSegmentCollector<T, F>
where
- T: 'static + Feature<Q, Agg = A>,
- Q: 'static,
- A: 'static + Mergeable + for<'a> From<&'a Q>,
+ T: 'static + Feature,
F: 'static + FeatureForDoc<T>,
{
- type Fruit = A;
+ type Fruit = T::Agg;

fn collect(&mut self, doc: DocId, _score: Score) {
if let Some(item) = self.reader.for_doc(doc) {
- item.collect_into(&self.query, &mut self.agg);
+ self.agg.collect(&self.query, &item);
}
}

@@ -196,92 +192,29 @@

use tantivy::{query::AllQuery, schema::SchemaBuilder, Document, Index};

- struct Metadata {
- a: i16,
- b: u16,
- }
-
// XXX Who will test the tests?
- impl Metadata {
- pub fn as_bytes(&self) -> [u8; 4] {
- let mut out = [0u8; 4];
- out[0..2].copy_from_slice(&self.a.to_le_bytes());
- out[2..].copy_from_slice(&self.b.to_le_bytes());
- out
- }
-
- pub fn from_bytes(src: [u8; 4]) -> Self {
- let a = i16::from_le_bytes(src[0..2].try_into().unwrap());
- let b = u16::from_le_bytes(src[2..].try_into().unwrap());
- Self { a, b }
- }
- }
-
- #[derive(Debug, Default)]
- struct MetaAgg {
- a: usize,
- b: usize,
- }
-
- impl Mergeable for MetaAgg {
- fn merge_same_size(&mut self, other: &Self) {
- self.a += other.a;
- self.b += other.b;
- }
- }
-
- #[derive(Clone)]
- struct LessThanMetaQuery {
- a: i16,
- b: u16,
- }
-
- impl From<&LessThanMetaQuery> for MetaAgg {
- fn from(_src: &LessThanMetaQuery) -> Self {
- Self::default()
- }
- }
-
- impl Feature<LessThanMetaQuery> for Metadata {
- type Agg = MetaAgg;
-
- fn collect_into(&self, query: &LessThanMetaQuery, agg: &mut Self::Agg) {
- if self.a < query.a {
- agg.a += 1;
- }
- if self.b < query.b {
- agg.b += 1;
- }
- }
- }
-
- #[derive(Clone)]
- struct CountARangesQuery(Vec<Range<i16>>);
-
- impl From<&CountARangesQuery> for Vec<i16> {
- fn from(src: &CountARangesQuery) -> Self {
- vec![0; src.0.len()]
- }
- }
-
- impl Mergeable for Vec<i16> {
+ impl Aggregator<Vec<Range<i16>>, i16> for Vec<i16> {
fn merge_same_size(&mut self, other: &Self) {
for (idx, tally) in other.iter().enumerate() {
self[idx] += tally;
}
}
- }
-
- impl Feature<CountARangesQuery> for Metadata {
- type Agg = Vec<i16>;
-
- fn collect_into(&self, query: &CountARangesQuery, agg: &mut Self::Agg) {
- for (idx, range) in query.0.iter().enumerate() {
- if range.contains(&self.a) {
- agg[idx] += 1;
+ fn collect(&mut self, query: &Vec<Range<i16>>, feature: &i16) {
+ for (idx, range) in query.iter().enumerate() {
+ if range.contains(&feature) {
+ self[idx] += 1;
}
}
}
+
+ fn from_query(query: &Vec<Range<i16>>) -> Self {
+ vec![0; query.len()]
+ }
+ }
+
+ impl Feature for i16 {
+ type Query = Vec<Range<i16>>;
+ type Agg = Vec<i16>;
}

#[test]
@@ -290,62 +223,37 @@
let bytes_field = builder.add_bytes_field("metadata_as_bytes");

let index = Index::create_in_ram(builder.build());
-
let mut writer = index.writer_with_num_threads(1, 3_000_000)?;

- let add_doc = |meta: Metadata| {
+ for i in -4i16..0 {
let mut doc = Document::new();
- doc.add_bytes(bytes_field, meta.as_bytes().to_vec());
+ doc.add_bytes(bytes_field, i.to_le_bytes().to_vec());
writer.add_document(doc);
- };
-
- add_doc(Metadata { a: -1, b: 1 });
- add_doc(Metadata { a: -2, b: 2 });
- add_doc(Metadata { a: -3, b: 3 });
- add_doc(Metadata { a: -4, b: 4 });
+ }

writer.commit()?;

let reader = index.reader()?;
let searcher = reader.searcher();

- let less_than_collector = FeatureCollector::<Metadata, _, _>::new(
- // So we want count:
- // * Every document that has "a" < -1
- // * Every document that has "b" < 2
- LessThanMetaQuery { a: -1, b: 2 },
+ let ranges_collector = FeatureCollector::<i16, _, _>::new(
+ vec![-10..0, 0..10, -2..4],
move |reader: &SegmentReader| {
let bytes_reader = reader.fast_fields().bytes(bytes_field).unwrap();

move |doc_id| {
- let metadata_bytes = bytes_reader.get_bytes(doc_id);
- metadata_bytes.try_into().ok().map(Metadata::from_bytes)
+ bytes_reader
+ .get_bytes(doc_id)
+ .try_into()
+ .ok()
+ .map(i16::from_le_bytes)
}
},
);

- let a_ranges_collector = FeatureCollector::<Metadata, _, _>::new(
- // And here we'll get a count for:
- // * Every doc that a is within -10..0 (4)
- // * Every doc that a is within 0..10 (0)
- // * Every doc that a is within -2..4 (2)
- CountARangesQuery(vec![-10..0, 0..10, -2..4]),
- move |reader: &SegmentReader| {
- let bytes_reader = reader.fast_fields().bytes(bytes_field).unwrap();
+ let range_counts = searcher.search(&AllQuery, &ranges_collector)?;

- move |doc_id| {
- let metadata_bytes = bytes_reader.get_bytes(doc_id);
- metadata_bytes.try_into().ok().map(Metadata::from_bytes)
- }
- },
- );
-
- let (agg, a_range_counts) =
- searcher.search(&AllQuery, &(less_than_collector, a_ranges_collector))?;
-
- assert_eq!(3, agg.a);
- assert_eq!(1, agg.b);
- assert_eq!(vec![4, 0, 2], a_range_counts);
+ assert_eq!(vec![4, 0, 2], range_counts);

Ok(())
}

Modified cantine_derive/internal/src/lib.rs

@@ -24,43 +24,11 @@

let agg_query = make_agg_query(&input);
let agg_result = make_agg_result(&input);
- let collector = impl_collector_traits(&input);

TokenStream::from(quote! {
#agg_query
#agg_result
- #collector
})
-}
-
-fn impl_collector_traits(input: &DeriveInput) -> TokenStream2 {
- let meta = &input.ident;
- let agg = format_ident!("{}AggregationResult", meta);
- let query = format_ident!("{}AggregationQuery", meta);
-
- quote! {
- impl cantine_derive::Mergeable for #agg {
- fn merge_same_size(&mut self, other: &Self) {
- <#agg>::merge_same_size(self, other);
- }
- }
-
- impl cantine_derive::Feature<#query> for #meta {
- type Agg = #agg;
-
- fn collect_into(&self, query: &#query, agg: &mut #agg) {
- agg.collect(&query, &self);
- }
- }
-
- impl cantine_derive::Feature<#query> for &#meta {
- type Agg = #agg;
-
- fn collect_into(&self, query: &#query, agg: &mut #agg) {
- agg.collect(&query, &self);
- }
- }
- }
}

fn make_filter_query(input: &DeriveInput) -> TokenStream2 {
@@ -359,14 +327,34 @@
#(#fields),*
}

+ impl cantine_derive::Feature for #feature {
+ type Query = #agg_query;
+ type Agg = #name;
+ }
+
+ impl cantine_derive::Aggregator<#agg_query, #feature> for #name {
+ fn merge_same_size(&mut self, other: &Self) {
+ <#name>::merge_same_size(self, other);
+ }
+
+ fn collect(&mut self, query: &#agg_query, feature: &#feature) {
+ <#name>::collect(self, query, feature);
+ }
+
+ fn from_query(query: &#agg_query) -> Self {
+ <#name>::from(query)
+ }
+ }
+
impl #name {
- pub fn merge_same_size(&mut self, other: &Self) {
+ fn merge_same_size(&mut self, other: &Self) {
#(#merge_code);*
}

- pub fn collect(&mut self, query: &#agg_query, feature: &#feature) {
+ fn collect(&mut self, query: &#agg_query, feature: &#feature) {
#(#collect_code);*
}
+
}

impl From<&#agg_query> for #name {