caio.co/de/cantine

Use new DocSet/Scorer API

Introduced on tantivy.git @ e25284ba

This changeset is sufficient; however, upstream's f71b04acb introduced
a `debug_assert!(self.doc() <= target)` for `SegmentPostings::seek`
that looks overzealous to me.

In release mode all tests pass, but given that a lot has changed since
I last looked, I'll be double-checking the affected functionality
before letting this go wild.
Id
13fca90e44bc8be5b9d1f8769138bce041396e23
Author
Caio
Commit time
2020-10-31T18:58:43+01:00

Modified tique/Cargo.toml

@@ -20,7 +20,7
queryparser = ["nom"]

[dependencies]
-tantivy = "0.12"
+tantivy = "0.13"
nom = { version = "5", optional = true }

[dev-dependencies]

Modified tique/src/dismax.rs

@@ -1,7 +1,7
use tantivy::{
self,
query::{EmptyScorer, Explanation, Query, Scorer, Weight},
- DocId, DocSet, Result, Score, Searcher, SegmentReader, SkipResult, TantivyError,
+ DocId, DocSet, Result, Score, Searcher, SegmentReader, TantivyError, TERMINATED,
};

/// A Maximum Disjunction query, as popularized by Lucene/Solr
@@ -87,7 +87,7
fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
let mut scorer = self.scorer(reader, 1.0)?;

- if scorer.skip_next(doc) != SkipResult::Reached {
+ if scorer.doc() > doc || scorer.seek(doc) != doc {
return Err(TantivyError::InvalidArgument("Not a match".to_owned()));
}

@@ -111,16 +111,17

struct DisMaxScorer {
scorers: Vec<Box<dyn Scorer>>,
- current: Option<DocId>,
+ current: DocId,
tiebreaker: f32,
}

impl DisMaxScorer {
fn new(scorers: Vec<Box<dyn Scorer>>, tiebreaker: f32) -> Self {
+ let current = scorers.iter().map(|s| s.doc()).min().unwrap_or(TERMINATED);
Self {
scorers,
tiebreaker,
- current: None,
+ current,
}
}
}
@@ -130,9 +131,8
let mut max = 0.0;
let mut sum = 0.0;

- debug_assert!(self.current.is_some());
for scorer in &mut self.scorers {
- if self.current.map_or(false, |d| scorer.doc() == d) {
+ if scorer.doc() == self.current {
let score = scorer.score();
sum += score;

@@ -147,20 +147,21
}

impl DocSet for DisMaxScorer {
- fn advance(&mut self) -> bool {
- let mut next_target = None;
+ fn advance(&mut self) -> DocId {
+ let mut next_target = TERMINATED;
let mut to_remove = Vec::new();

for (idx, scorer) in self.scorers.iter_mut().enumerate() {
// Advance every scorer that's on target or behind
- if self.current.map_or(true, |d| d >= scorer.doc()) && !scorer.advance() {
+ if scorer.doc() <= self.current && scorer.advance() == TERMINATED {
to_remove.push(idx);
continue;
}

let doc = scorer.doc();
- if next_target.map_or(true, |next| doc < next) {
- next_target.replace(doc);
+
+ if doc < next_target {
+ next_target = doc;
}
}

@@ -168,17 +169,12
self.scorers.remove(idx);
}

- if let Some(target) = next_target {
- self.current.replace(target);
- true
- } else {
- false
- }
+ self.current = next_target;
+ next_target
}

fn doc(&self) -> tantivy::DocId {
- debug_assert!(self.current.is_some());
- self.current.unwrap_or(0)
+ self.current
}

fn size_hint(&self) -> u32 {
@@ -190,7 +186,7
mod tests {
use super::*;

- use std::{num::Wrapping, ops::Range};
+ use std::ops::Range;

use tantivy::{
doc,
@@ -199,11 +195,10
DocAddress, Index, Term,
};

- // XXX ConstScorer::from(VecDocSet::from(...)), but I can't seem
- // import tantivy::query::VecDocSet here??
struct VecScorer {
doc_ids: Vec<DocId>,
- cursor: Wrapping<usize>,
+ cursor: usize,
+ current: DocId,
}

impl Scorer for VecScorer {
@@ -212,14 +207,31
}
}

+ impl From<Range<DocId>> for VecScorer {
+ fn from(range: Range<DocId>) -> Self {
+ let doc_ids: Vec<DocId> = range.collect();
+ let current = *doc_ids.first().unwrap_or(&TERMINATED);
+ Self {
+ doc_ids,
+ cursor: 0,
+ current,
+ }
+ }
+ }
+
impl DocSet for VecScorer {
- fn advance(&mut self) -> bool {
- self.cursor += Wrapping(1);
- self.doc_ids.len() > self.cursor.0
+ fn advance(&mut self) -> DocId {
+ self.cursor += 1;
+ if self.cursor >= self.doc_ids.len() {
+ self.current = TERMINATED;
+ } else {
+ self.current = self.doc_ids[self.cursor];
+ }
+ self.doc()
}

fn doc(&self) -> DocId {
- self.doc_ids[self.cursor.0]
+ self.current
}

fn size_hint(&self) -> u32 {
@@ -227,61 +239,70
}
}

- fn test_scorer(range: Range<DocId>) -> Box<dyn Scorer> {
- Box::new(VecScorer {
- doc_ids: range.collect(),
- cursor: Wrapping(usize::max_value()),
- })
+ fn make_test_scorer(range: Range<DocId>) -> Box<dyn Scorer> {
+ Box::new(VecScorer::from(range))
}

#[test]
fn scorer_advances_as_union() {
+ // The union of all doc ids will be:
+ // 0,1,2,3,...,29,42
let scorers = vec![
- test_scorer(0..10),
- test_scorer(5..20),
- test_scorer(9..30),
- test_scorer(42..43),
- test_scorer(13..13), // empty docset
+ make_test_scorer(0..10),
+ make_test_scorer(5..20),
+ make_test_scorer(9..30),
+ make_test_scorer(42..43),
+ make_test_scorer(13..13), // empty docset
];

let mut dismax = DisMaxScorer::new(scorers, 0.0);

for i in 0..30 {
- assert!(dismax.advance(), "failed advance at i={}", i);
assert_eq!(i, dismax.doc());
+ assert!(dismax.advance() != TERMINATED, "failed advance at i={}", i);
}

- assert!(dismax.advance());
assert_eq!(42, dismax.doc());
- assert!(!dismax.advance(), "scorer should have ended by now");
+ assert!(
+ dismax.advance() == TERMINATED,
+ "scorer should have ended by now"
+ );
}

#[test]
#[allow(clippy::float_cmp)]
fn tiebreaker() {
- let scorers = vec![test_scorer(4..5), test_scorer(4..6), test_scorer(4..7)];
+ let scorers = vec![
+ make_test_scorer(4..5),
+ make_test_scorer(4..6),
+ make_test_scorer(4..7),
+ ];

// So now the score is the sum of scores for
// every matching scorer (VecScorer always yields 1)
let mut dismax = DisMaxScorer::new(scorers, 1.0);

- assert!(dismax.advance());
assert_eq!(3.0, dismax.score());
- assert!(dismax.advance());
+ assert!(dismax.advance() != TERMINATED);
assert_eq!(2.0, dismax.score());
- assert!(dismax.advance());
+ assert!(dismax.advance() != TERMINATED);
assert_eq!(1.0, dismax.score());
- assert!(!dismax.advance(), "scorer should have ended by now");
+ assert!(
+ dismax.advance() == TERMINATED,
+ "scorer should have ended by now"
+ );

- let scorers = vec![test_scorer(7..8), test_scorer(7..8)];
+ let scorers = vec![make_test_scorer(7..8), make_test_scorer(7..8)];

// With a tiebreaker 0, it actually uses
// the maximum disjunction
let mut dismax = DisMaxScorer::new(scorers, 0.0);
- assert!(dismax.advance());
// So now, even though doc=7 occurs twice, the score is 1
assert_eq!(1.0, dismax.score());
- assert!(!dismax.advance(), "scorer should have ended by now");
+ assert!(
+ dismax.advance() == TERMINATED,
+ "scorer should have ended by now"
+ );
}

#[test]

Modified tique/src/topterms.rs

@@ -69,7 +69,7
query::{BooleanQuery, BoostQuery, Occur, Query, TermQuery},
schema::{Field, FieldType, IndexRecordOption, Schema},
tokenizer::TextAnalyzer,
- DocAddress, DocSet, Index, IndexReader, Postings, Result, Searcher, SkipResult, Term,
+ DocAddress, DocSet, Index, IndexReader, Postings, Result, Searcher, Term,
};

use crate::conditional_collector::topk::{DescendingTopK, TopK};
@@ -297,7 +297,7
let mut postings =
inverted_index.read_postings_from_terminfo(terminfo, IndexRecordOption::WithFreqs);

- if postings.skip_next(doc_id) == SkipResult::Reached {
+ if postings.seek(doc_id) == doc_id {
let term = Term::from_field_text(field, text);
consumer(term, postings.term_freq());
}
@@ -437,7 +437,7
let topterms = TopTerms::new(&index, vec![source, quote])?;

let keyword_filter = |term: &Term, _tf, doc_freq, num_docs| {
- // Only words with more than characters 3
+ // Only words with more than 3 characters
term.text().chars().count() > 3
// that do NOT appear in every document at this field
&& doc_freq < num_docs

Modified tique/src/conditional_collector/custom_score.rs

@@ -46,8 +46,8
where
T: 'static + PartialOrd + Copy + Send + Sync,
P: 'static + Send + Sync + TopKProvider<T, DocId>,
- C: Sync + ConditionForSegment<T>,
- S: CustomScorer<T>,
+ C: Send + Sync + ConditionForSegment<T>,
+ S: Send + CustomScorer<T>,
{
type Fruit = CollectionResult<T>;
type Child = CustomScoreTopSegmentCollector<T, C::Type, S::Child, P::Child>;

Modified tique/src/conditional_collector/top_collector.rs

@@ -120,7 +120,7
{
/// Transforms this collector into one that uses the given
/// scorer instead of the default scoring functionality.
- pub fn with_custom_scorer<C: CustomScorer<T>>(
+ pub fn with_custom_scorer<C: Send + CustomScorer<T>>(
self,
custom_scorer: C,
) -> impl Collector<Fruit = CollectionResult<T>> {
@@ -167,7 +167,7
impl<P, CF> Collector for TopCollector<Score, P, CF>
where
P: 'static + Send + Sync + TopKProvider<Score, DocId>,
- CF: Sync + ConditionForSegment<Score>,
+ CF: Send + Sync + ConditionForSegment<Score>,
{
type Fruit = CollectionResult<Score>;
type Child = TopSegmentCollector<Score, P::Child, CF::Type>;