Merge branch 'multi_field_query_parser'
- Id: bc64a1c7db2475fd6f8c34adf1d1fa58974d1184
- Author: Caio
- Commit time: 2020-03-09T13:29:01+01:00
Modified tique/Cargo.toml
[dependencies]
tantivy = "0.12"
nom = { version = "5", optional = true }
+
+[dev-dependencies]
+quickcheck = "0.9"
Modified cantine/src/main.rs
let mut subqueries: Vec<(Occur, Box<dyn Query>)> = Vec::new();
if let Some(fulltext) = &query.fulltext {
- if let Some(parsed) = self.query_parser.parse(fulltext.as_str())? {
+ if let Some(parsed) = self.query_parser.parse(fulltext.as_str()) {
subqueries.push((Occur::Must, parsed));
}
}
let index = Index::open_in_dir(&index_path)?;
let recipe_index = RecipeIndex::try_from(&index.schema())?;
- let query_parser = QueryParser::new(
- recipe_index.fulltext,
- index.tokenizer_for_field(recipe_index.fulltext)?,
- true,
- );
+ let query_parser = QueryParser::new(&index, vec![recipe_index.fulltext])?;
let reader = index.reader()?;
let search_state = Arc::new(SearchState {
Modified cantine/tests/index_integration.rs
let reader = GLOBAL.index.reader()?;
let searcher = reader.searcher();
- let parser = QueryParser::new(
- GLOBAL.cantine.fulltext,
- GLOBAL.index.tokenizer_for_field(GLOBAL.cantine.fulltext)?,
- true,
- );
+ let parser = QueryParser::new(&GLOBAL.index, vec![GLOBAL.cantine.fulltext])?;
- let query = parser.parse("potato cheese")?.unwrap();
+ let query = parser.parse("+potato +cheese").unwrap();
let (_total, found_ids, _next) =
GLOBAL
Modified tique/src/lib.rs
#[cfg(feature = "unstable")]
pub mod queryparser;
+
+mod dismax;
+pub use dismax::DisMaxQuery;
Modified tique/src/queryparser/mod.rs
-mod interpreter;
mod parser;
+mod raw;
-pub use interpreter::QueryParser;
+pub use parser::QueryParser;
Modified tique/src/queryparser/parser.rs
-use nom::{
+use super::raw::{parse_query, FieldNameValidator, RawQuery};
+
+use tantivy::{
self,
- branch::alt,
- bytes::complete::take_while1,
- character::complete::{char as is_char, multispace0},
- combinator::map,
- multi::many0,
- sequence::{delimited, preceded},
- IResult,
+ query::{AllQuery, BooleanQuery, BoostQuery, Occur, PhraseQuery, Query, TermQuery},
+ schema::{Field, IndexRecordOption},
+ tokenizer::TextAnalyzer,
+ Index, Result, Term,
};
-#[derive(Debug, PartialEq)]
-pub enum Token<'a> {
- Phrase(&'a str, bool),
- Term(&'a str, bool),
+pub struct QueryParser {
+ state: Vec<(Option<String>, Option<f32>, Interpreter)>,
+ default_indices: Vec<usize>,
}
-fn parse_not_phrase(input: &str) -> IResult<&str, Token> {
- map(preceded(is_char('-'), parse_phrase), |t| match t {
- Token::Phrase(inner, false) => Token::Phrase(inner, true),
- _ => unreachable!(),
- })(input)
+impl QueryParser {
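+    /// Creates a parser over the given fields: each one is registered
+    /// under its schema name with its configured tokenizer, and all of
+    /// them start out as default (unprefixed-search) fields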
+ pub fn new(index: &Index, fields: Vec<Field>) -> Result<Self> {
+ let schema = index.schema();
+
+ let mut parser = QueryParser {
+ default_indices: (0..fields.len()).collect(),
+ state: Vec::with_capacity(fields.len()),
+ };
+
+ for field in fields.into_iter() {
+ parser.state.push((
+ Some(schema.get_field_name(field).to_owned()),
+ None,
+ Interpreter {
+ field,
+ analyzer: index.tokenizer_for_field(field)?,
+ },
+ ));
+ }
+
+ Ok(parser)
+ }
+
+ pub fn set_boost(&mut self, field: Field, boost: Option<f32>) {
+ if let Some(row) = self
+ .position_by_field(field)
+ .map(|pos| self.state.get_mut(pos))
+ .flatten()
+ {
+ row.1 = boost;
+ }
+ }
+
+ pub fn set_name(&mut self, field: Field, name: Option<String>) {
+ if let Some(row) = self
+ .position_by_field(field)
+ .map(|pos| self.state.get_mut(pos))
+ .flatten()
+ {
+ row.0 = name;
+ }
+ }
+
+ pub fn set_default_fields(&mut self, fields: Vec<Field>) {
+ let mut indices = Vec::with_capacity(fields.len());
+ for field in fields.into_iter() {
+ if let Some(idx) = self.position_by_field(field) {
+ indices.push(idx);
+ }
+ }
+ indices.sort();
+ self.default_indices = indices;
+ }
+
+ fn position_by_name(&self, field_name: &str) -> Option<usize> {
+ self.state
+ .iter()
+ .position(|(opt_name, _opt_boost, _interpreter)| {
+ opt_name
+ .as_ref()
+ .map(|name| name == field_name)
+ .unwrap_or(false)
+ })
+ }
+
+ fn position_by_field(&self, field: Field) -> Option<usize> {
+ self.state
+ .iter()
+ .position(|(_opt_name, _opt_boost, interpreter)| interpreter.field == field)
+ }
+
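+    /// Parses the input into a tantivy query; yields None when no
+    /// usable clause survives (empty input, or tokens the analyzers
+    /// discard). A lone MustNot clause gets wrapped with AllQuery so
+    /// that "-term" means "everything not matching term"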
+ pub fn parse(&self, input: &str) -> Option<Box<dyn Query>> {
+ let (_, parsed) = parse_query(input, self).ok()?;
+
+ match parsed.len() {
+ 0 => None,
+ 1 => {
+ let raw = &parsed[0];
+ let query = self.query_from_raw(&raw)?;
+
+ if raw.occur == Occur::MustNot {
+ Some(Box::new(BooleanQuery::from(vec![
+ (Occur::MustNot, query),
+ (Occur::Must, Box::new(AllQuery)),
+ ])))
+ } else {
+ Some(query)
+ }
+ }
+ _ => {
+ let mut subqueries: Vec<(Occur, Box<dyn Query>)> = Vec::new();
+
+ let mut num_must_not = 0;
+ for tok in parsed {
+ if let Some(query) = self.query_from_raw(&tok) {
+ if tok.occur == Occur::MustNot {
+ num_must_not += 1;
+ }
+
+ subqueries.push((tok.occur, query));
+ }
+ }
+
+            // Detect boolean queries with only MustNot clauses
+            // and append an AllQuery, otherwise the resulting
+            // query would match nothing
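+            // (e.g. "-story -love" becomes two MustNot clauses
+            // plus an appended (Must, AllQuery))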
+            if num_must_not > 0 && num_must_not == subqueries.len() {
+ subqueries.push((Occur::Must, Box::new(AllQuery)));
+ }
+
+ match subqueries.len() {
+ 0 => None,
+ 1 => Some(subqueries.pop().expect("Element always present").1),
+ _ => Some(Box::new(BooleanQuery::from(subqueries))),
+ }
+ }
+ }
+ }
+
+ fn query_from_raw(&self, raw_query: &RawQuery) -> Option<Box<dyn Query>> {
+ let indices = if let Some(position) = raw_query
+ .field_name
+ .map(|field_name| self.position_by_name(field_name))
+ .flatten()
+ {
+ vec![position]
+ } else {
+ self.default_indices.clone()
+ };
+
+ let queries: Vec<Box<dyn Query>> = indices
+ .into_iter()
+ .flat_map(|i| self.state.get(i))
+ .flat_map(|(_, boost, interpreter)| {
+ interpreter.to_query(raw_query).map(|query| {
+ if let Some(val) = boost {
+ Box::new(BoostQuery::new(query, *val))
+ } else {
+ query
+ }
+ })
+ })
+ .collect();
+
+ match queries.len() {
+ 0 => None,
+ 1 => Some(queries.into_iter().nth(0).unwrap()),
+ _ => Some(Box::new(BooleanQuery::from(
+ queries
+ .into_iter()
+ .map(|q| (Occur::Should, q))
+ .collect::<Vec<_>>(),
+ ))),
+ }
+ }
}
-fn parse_phrase(input: &str) -> IResult<&str, Token> {
- map(
- delimited(is_char('"'), take_while1(|c| c != '"'), is_char('"')),
- |s| Token::Phrase(s, false),
- )(input)
+impl FieldNameValidator for QueryParser {
+ fn check(&self, field_name: &str) -> bool {
+ self.state
+ .iter()
+ .any(|(opt_name, _opt_boost, _interpreter)| {
+ opt_name
+ .as_ref()
+ .map(|name| name == field_name)
+ .unwrap_or(false)
+ })
+ }
}
-fn parse_term(input: &str) -> IResult<&str, Token> {
- map(take_while1(is_term_char), |s| Token::Term(s, false))(input)
+struct Interpreter {
+ field: Field,
+ analyzer: TextAnalyzer,
}
-fn parse_not_term(input: &str) -> IResult<&str, Token> {
- map(preceded(is_char('-'), parse_term), |t| match t {
- Token::Term(inner, false) => Token::Term(inner, true),
- _ => unreachable!(),
- })(input)
-}
+impl Interpreter {
+ fn to_query(&self, raw_query: &RawQuery) -> Option<Box<dyn Query>> {
+ let mut terms = Vec::new();
+ let mut stream = self.analyzer.token_stream(raw_query.input);
-fn is_term_char(c: char) -> bool {
- !(c == ' ' || c == '\t' || c == '\r' || c == '\n')
-}
+ stream.process(&mut |token| {
+ terms.push(Term::from_field_text(self.field, &token.text));
+ });
-pub fn parse_query(input: &str) -> IResult<&str, Vec<Token>> {
- many0(delimited(
- multispace0,
- alt((parse_not_phrase, parse_phrase, parse_not_term, parse_term)),
- multispace0,
- ))(input)
+ if terms.is_empty() {
+ return None;
+ }
+
+ let query: Box<dyn Query> = if terms.len() == 1 {
+ Box::new(TermQuery::new(
+ terms.pop().unwrap(),
+ IndexRecordOption::WithFreqs,
+ ))
+ } else if raw_query.is_phrase {
+ Box::new(PhraseQuery::new(terms))
+ } else {
+            // An analyzer might emit multiple tokens even if the
+            // raw parser only saw one (say: the raw parser takes
+            // "word", but the analyzer is a char tokenizer)
+ Box::new(BooleanQuery::new_multiterms_query(terms))
+ };
+
+ Some(query)
+ }
}
#[cfg(test)]
mod tests {
use super::*;
- use super::Token::*;
+ use tantivy::tokenizer::TokenizerManager;
- #[test]
- fn term_extraction() {
- assert_eq!(parse_term("gula"), Ok(("", Term("gula", false))));
+ fn test_interpreter() -> Interpreter {
+ Interpreter {
+ field: Field::from_field_id(0),
+ analyzer: TokenizerManager::default().get("en_stem").unwrap(),
+ }
}
#[test]
- fn not_term_extraction() {
- assert_eq!(parse_not_term("-ads"), Ok(("", Term("ads", true))))
+ fn empty_raw_is_none() {
+ assert!(test_interpreter().to_query(&RawQuery::new("")).is_none());
}
#[test]
- fn phrase_extraction() {
- assert_eq!(
- parse_phrase("\"gula recipes\""),
- Ok(("", Phrase("gula recipes", false)))
- );
+ fn simple_raw_is_termquery() {
+ let query = test_interpreter()
+ .to_query(&RawQuery::new("word"))
+ .expect("parses to a Some(Query)");
+
+ assert!(query.as_any().downcast_ref::<TermQuery>().is_some());
}
#[test]
- fn not_phrase_extraction() {
- assert_eq!(
- parse_not_phrase("-\"ads and tracking\""),
- Ok(("", Phrase("ads and tracking", true)))
- );
+ fn phrase_raw_is_phrasequery() {
+ let query = test_interpreter()
+ .to_query(&RawQuery::new("sweet potato").phrase())
+ .expect("parses to a Some(Query)");
+
+ assert!(query.as_any().downcast_ref::<PhraseQuery>().is_some());
}
#[test]
- fn empty_term_not_allowed() {
- assert!(parse_term("").is_err());
+ fn single_word_raw_phrase_is_termquery() {
+ let query = test_interpreter()
+ .to_query(&RawQuery::new("single").phrase())
+ .expect("parses to a Some(Query)");
+
+ assert!(query.as_any().downcast_ref::<TermQuery>().is_some());
+ }
+
+ fn single_field_test_parser() -> QueryParser {
+ QueryParser {
+ default_indices: vec![0],
+ state: vec![(
+ None,
+ None,
+ Interpreter {
+ field: Field::from_field_id(0),
+ analyzer: TokenizerManager::default().get("en_stem").unwrap(),
+ },
+ )],
+ }
}
#[test]
- fn empty_phrase_not_allowed() {
- assert!(parse_phrase("\"\"").is_err());
+ fn empty_query_results_in_none() {
+ assert!(single_field_test_parser().parse("").is_none());
+ }
+
+ use tantivy::{
+ collector::TopDocs,
+ doc,
+ schema::{SchemaBuilder, TEXT},
+ DocAddress,
+ };
+
+ #[test]
+ fn index_integration() -> Result<()> {
+ let mut builder = SchemaBuilder::new();
+ let title = builder.add_text_field("title", TEXT);
+ let plot = builder.add_text_field("plot", TEXT);
+ let index = Index::create_in_ram(builder.build());
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
+
+ let doc_across = DocAddress(0, 0);
+ writer.add_document(doc!(
+ title => "Across the Universe",
+ plot => "Musical based on The Beatles songbook and set in the 60s England, \
+ America, and Vietnam. The love story of Lucy and Jude is intertwined \
+ with the anti-war movement and social protests of the 60s."
+ ));
+
+ let doc_moulin = DocAddress(0, 1);
+ writer.add_document(doc!(
+ title => "Moulin Rouge!",
+ plot => "A poet falls for a beautiful courtesan whom a jealous duke covets in \
+ this stylish musical, with music drawn from familiar 20th century sources."
+ ));
+
+ let doc_once = DocAddress(0, 2);
+ writer.add_document(doc!(
+ title => "Once",
+ plot => "A modern-day musical about a busker and an immigrant and their eventful\
+ week in Dublin, as they write, rehearse and record songs that tell their \
+ love story."
+ ));
+
+ writer.commit()?;
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ let parser = QueryParser::new(&index, vec![title, plot])?;
+
+ let search = |input, limit| {
+ let query = parser.parse(input).expect("given input yields Some()");
+ searcher
+ .search(&query, &TopDocs::with_limit(limit))
+ .expect("working index")
+ };
+
+ let found = search("+title:Once musical", 2);
+ // Even if "musical" matches every document,
+ // there's a MUST query that only one matches
+ assert_eq!(1, found.len());
+ assert_eq!(doc_once, found[0].1);
+
+ let found = search("\"the beatles\"", 1);
+ assert!(!found.is_empty());
+ assert_eq!(doc_across, found[0].1);
+
+        // Purely negative queries should work too
+        for input in &["-story -love", "-\"love story\""] {
+            let found = search(input, 3);
+            assert_eq!(1, found.len());
+            assert_eq!(doc_moulin, found[0].1);
+        }
+
+        // Trailing garbage tokenizes to nothing and is dropped
+        let found = search("-love -", 3);
+        assert_eq!(1, found.len());
+        assert_eq!(doc_moulin, found[0].1);
+
+ Ok(())
}
#[test]
- fn parse_query_works() {
- assert_eq!(
- parse_query(" peanut -\"peanut butter\" -sugar "),
- Ok((
- "",
- vec![
- Term("peanut", false),
- Phrase("peanut butter", true),
- Term("sugar", true)
- ]
- ))
- );
- }
+ fn field_boosting() -> Result<()> {
+ let mut builder = SchemaBuilder::new();
+ let field_a = builder.add_text_field("a", TEXT);
+ let field_b = builder.add_text_field("b", TEXT);
+ let index = Index::create_in_ram(builder.build());
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
- #[test]
- fn parse_query_accepts_empty_string() {
- assert_eq!(parse_query(""), Ok(("", vec![])));
- assert_eq!(parse_query(" "), Ok((" ", vec![])));
- }
+ writer.add_document(doc!(
+ field_a => "bar",
+ field_b => "foo baz",
+ ));
- #[test]
- fn garbage_is_extracted_as_term() {
- assert_eq!(
- parse_query("- \""),
- Ok(("", vec![Term("-", false), Term("\"", false)]))
- );
+ writer.add_document(doc!(
+ field_a => "foo",
+ field_b => "bar",
+ ));
+
+ writer.add_document(doc!(
+ field_a => "bar",
+ field_b => "foo",
+ ));
+
+ writer.commit()?;
+
+ let mut parser = QueryParser::new(&index, vec![field_a, field_b])?;
+
+ let input = "foo baz";
+ let normal_query = parser.parse(&input).unwrap();
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ let found = searcher.search(&normal_query, &TopDocs::with_limit(3))?;
+ assert_eq!(3, found.len());
+ // the first doc matches perfectly on `field_b`
+ assert_eq!(DocAddress(0, 0), found[0].1);
+
+ parser.set_boost(field_a, Some(1.5));
+ let boosted_query = parser.parse(&input).unwrap();
+
+ let found = searcher.search(&boosted_query, &TopDocs::with_limit(3))?;
+ assert_eq!(3, found.len());
+ // the first doc matches perfectly on field_b
+ // but now matching on `field_a` is super important
+ assert_eq!(DocAddress(0, 1), found[0].1);
+
+ Ok(())
}
}
Renamed tique/src/queryparser/interpreter.rs to tique/src/queryparser/raw.rs
-use super::parser::{parse_query, Token};
-
-use tantivy::{
+use nom::{
self,
- query::{AllQuery, BooleanQuery, Occur, PhraseQuery, Query, TermQuery},
- schema::{Field, IndexRecordOption},
- tokenizer::TextAnalyzer,
- Result, Term,
+ branch::alt,
+ bytes::complete::take_while1,
+ character::complete::{char as is_char, multispace0},
+ combinator::{map, map_res},
+ multi::many0,
+ sequence::{delimited, preceded, separated_pair},
+ IResult,
};
+use tantivy::query::Occur;
-pub struct QueryParser {
- field: Field,
- tokenizer: TextAnalyzer,
- occur: Occur,
+#[derive(Debug, PartialEq)]
+pub struct RawQuery<'a> {
+ pub input: &'a str,
+ pub is_phrase: bool,
+ pub field_name: Option<&'a str>,
+ pub occur: Occur,
}
-impl QueryParser {
- pub fn new(field: Field, tokenizer: TextAnalyzer, match_all: bool) -> QueryParser {
- QueryParser {
- field,
- tokenizer,
- occur: if match_all {
- Occur::Must
+const FIELD_SEP: char = ':';
+
+impl<'a> RawQuery<'a> {
+ pub fn new(input: &'a str) -> Self {
+ Self {
+ input,
+ is_phrase: false,
+ field_name: None,
+ occur: Occur::Should,
+ }
+ }
+
+ pub fn must_not(mut self) -> Self {
+ debug_assert_eq!(Occur::Should, self.occur);
+ self.occur = Occur::MustNot;
+ self
+ }
+
+ pub fn must(mut self) -> Self {
+ debug_assert_eq!(Occur::Should, self.occur);
+ self.occur = Occur::Must;
+ self
+ }
+
+ pub fn phrase(mut self) -> Self {
+ debug_assert!(!self.is_phrase);
+ self.is_phrase = true;
+ self
+ }
+
+ pub fn with_field(mut self, name: &'a str) -> Self {
+ debug_assert_eq!(None, self.field_name);
+ self.field_name = Some(name);
+ self
+ }
+}
+
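+/// Tells the parser whether something that looks like a field prefix
+/// ("name:") refers to an actual field or is just part of the term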
+pub trait FieldNameValidator {
+ fn check(&self, field_name: &str) -> bool;
+}
+
+impl<T> FieldNameValidator for Vec<T>
+where
+ T: for<'a> PartialEq<&'a str>,
+{
+ fn check(&self, field_name: &str) -> bool {
+ self.iter().any(|item| item == &field_name)
+ }
+}
+
+impl FieldNameValidator for bool {
+ fn check(&self, _field_name: &str) -> bool {
+ *self
+ }
+}
+
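+/// Splits the input into a list of RawQuery clauses. This never
+/// fails: fragments that don't parse as structure are kept as terms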
+pub fn parse_query<'a, C: FieldNameValidator>(
+ input: &'a str,
+ validator: &'a C,
+) -> IResult<&'a str, Vec<RawQuery<'a>>> {
+ many0(delimited(
+ multispace0,
+ alt((
+ |input| prohibited_query(input, validator),
+ |input| mandatory_query(input, validator),
+ |input| field_prefixed_query(input, validator),
+ any_field_query,
+ )),
+ multispace0,
+ ))(input)
+}
+
+fn prohibited_query<'a, C: FieldNameValidator>(
+ input: &'a str,
+ validator: &'a C,
+) -> IResult<&'a str, RawQuery<'a>> {
+ map(
+ preceded(
+ is_char('-'),
+ alt((
+ |input| field_prefixed_query(input, validator),
+ any_field_query,
+ )),
+ ),
+ |query| query.must_not(),
+ )(input)
+}
+
+fn mandatory_query<'a, C: FieldNameValidator>(
+ input: &'a str,
+ validator: &'a C,
+) -> IResult<&'a str, RawQuery<'a>> {
+ map(
+ preceded(
+ is_char('+'),
+ alt((
+ |input| field_prefixed_query(input, validator),
+ any_field_query,
+ )),
+ ),
+ |query| query.must(),
+ )(input)
+}
+
+fn field_prefixed_query<'a, C: FieldNameValidator>(
+ input: &'a str,
+ validator: &'a C,
+) -> IResult<&'a str, RawQuery<'a>> {
+ map_res(
+ separated_pair(
+ take_while1(|c| c != FIELD_SEP && is_term_char(c)),
+ is_char(FIELD_SEP),
+ any_field_query,
+ ),
+ |(name, term)| {
+ if validator.check(name) {
+ Ok(term.with_field(name))
} else {
- Occur::Should
- },
- }
- }
-
- pub fn parse(&self, input: &str) -> Result<Option<Box<dyn Query>>> {
- let (_, parsed) = parse_query(input)
- .map_err(|e| tantivy::TantivyError::InvalidArgument(format!("{:?}", e)))?;
-
- Ok(match parsed.len() {
- 0 => None,
- 1 => self.query_from_token(&parsed[0])?,
- _ => {
- let mut subqueries: Vec<(Occur, Box<dyn Query>)> = Vec::new();
-
- for tok in parsed {
- if let Some(query) = self.query_from_token(&tok)? {
- subqueries.push((self.occur, query));
- }
- }
-
- match subqueries.len() {
- 0 => None,
- 1 => Some(subqueries.pop().expect("Element always present").1),
- _ => Some(Box::new(BooleanQuery::from(subqueries))),
- }
+ Err("Invalid field")
}
- })
- }
+ },
+ )(input)
+}
- fn assemble_query(&self, text: &str, allow_phrase: bool) -> Result<Option<Box<dyn Query>>> {
- let tokens = self.tokenize(text);
+fn any_field_query(input: &str) -> IResult<&str, RawQuery> {
+ alt((parse_phrase, parse_term))(input)
+}
- match &tokens[..] {
- [] => Ok(None),
- [(_, term)] => Ok(Some(Box::new(TermQuery::new(
- term.clone(),
- IndexRecordOption::WithFreqs,
- )))),
- _ => {
- if allow_phrase {
- Ok(Some(Box::new(PhraseQuery::new_with_offset(tokens))))
- } else {
- Err(tantivy::TantivyError::InvalidArgument(
- "More than one token found but allow_phrase is false".to_owned(),
- ))
- }
- }
- }
- }
+fn parse_phrase(input: &str) -> IResult<&str, RawQuery> {
+ map(
+ delimited(is_char('"'), take_while1(|c| c != '"'), is_char('"')),
+ |s| RawQuery::new(s).phrase(),
+ )(input)
+}
- //Not[Inner] queries are always [MatchAllDocs() - Inner]
- fn negate_query(inner: Box<dyn Query>) -> Box<dyn Query> {
- let subqueries: Vec<(Occur, Box<dyn Query>)> =
- vec![(Occur::MustNot, inner), (Occur::Must, Box::new(AllQuery))];
+fn parse_term(input: &str) -> IResult<&str, RawQuery> {
+ map(take_while1(is_term_char), RawQuery::new)(input)
+}
- let bq: BooleanQuery = subqueries.into();
- Box::new(bq)
- }
-
- // May result in Ok(None) because the tokenizer might give us nothing
- fn query_from_token(&self, token: &Token) -> Result<Option<Box<dyn Query>>> {
- let (query, negate) = match token {
- Token::Term(t, neg) => (self.assemble_query(t, false)?, *neg),
- Token::Phrase(p, neg) => (self.assemble_query(p, true)?, *neg),
- };
-
- if negate {
- Ok(query.map(|inner| Self::negate_query(inner)))
- } else {
- Ok(query)
- }
- }
-
- fn tokenize(&self, phrase: &str) -> Vec<(usize, Term)> {
- let mut terms: Vec<(usize, Term)> = Vec::new();
- let mut stream = self.tokenizer.token_stream(phrase);
-
- stream.process(&mut |token| {
- let term = Term::from_field_text(self.field, &token.text);
- terms.push((token.position, term));
- });
-
- terms
- }
+fn is_term_char(c: char) -> bool {
+ !(c == ' ' || c == '\t' || c == '\r' || c == '\n')
}
#[cfg(test)]
mod tests {
use super::*;
- use tantivy::tokenizer::TokenizerManager;
+ fn parse_no_fields(input: &str) -> IResult<&str, Vec<RawQuery>> {
+ parse_query(input, &false)
+ }
- fn test_parser() -> QueryParser {
- QueryParser::new(
- Field::from_field_id(0),
- TokenizerManager::default().get("en_stem").unwrap(),
- true,
+ #[test]
+ fn term_extraction() {
+ assert_eq!(
+ parse_no_fields("gula"),
+ Ok(("", vec![RawQuery::new("gula")]))
+ );
+ }
+
+ #[test]
+ fn prohibited_term_extraction() {
+ assert_eq!(
+ parse_no_fields("-ads"),
+ Ok(("", vec![RawQuery::new("ads").must_not()]))
)
}
- fn parsed(input: &str) -> Box<dyn Query> {
- test_parser()
- .parse(input)
- .unwrap()
- .expect("Should have gotten Some(dyn Query)")
+ #[test]
+ fn mandatory_term_extraction() {
+ assert_eq!(
+ parse_no_fields("+love"),
+ Ok(("", vec![RawQuery::new("love").must()]))
+ )
}
#[test]
- fn can_parse_term_query() {
- assert!(parsed("gula")
- .as_any()
- .downcast_ref::<TermQuery>()
- .is_some());
+ fn phrase_extraction() {
+ assert_eq!(
+ parse_no_fields("\"gula recipes\""),
+ Ok(("", vec![RawQuery::new("gula recipes").phrase()]))
+ );
}
#[test]
- fn can_parse_phrase_query() {
- assert!(parsed(" \"gula recipes\" ")
- .as_any()
- .downcast_ref::<PhraseQuery>()
- .is_some());
+ fn prohibited_phrase_extraction() {
+ assert_eq!(
+ parse_no_fields("-\"ads and tracking\""),
+ Ok((
+ "",
+ vec![RawQuery::new("ads and tracking").must_not().phrase()]
+ ))
+ );
}
#[test]
- fn single_term_phrase_query_becomes_term_query() {
- assert!(parsed(" \"gula\" ")
- .as_any()
- .downcast_ref::<TermQuery>()
- .is_some());
+ fn mandatory_phrase_extraction() {
+ assert_eq!(
+ parse_no_fields("+\"great food\""),
+ Ok(("", vec![RawQuery::new("great food").must().phrase()]))
+ );
}
#[test]
- fn negation_works() {
- let input = vec!["-hunger", "-\"ads and tracking\""];
-
- for i in input {
- let p = parsed(i);
- let query = p
- .as_any()
- .downcast_ref::<BooleanQuery>()
- .expect("Must be a boolean query");
-
- let clauses = query.clauses();
-
- assert_eq!(2, clauses.len());
- // XXX First clause is the wrapped {Term,Phrase}Query
-
- // Second clause is the MatchAllDocs()
- let (occur, inner) = &clauses[1];
- assert_eq!(Occur::Must, *occur);
- assert!(inner.as_any().downcast_ref::<AllQuery>().is_some())
- }
+ fn parse_query_works() {
+ assert_eq!(
+ parse_no_fields(" +peanut -\"peanut butter\" -sugar roast"),
+ Ok((
+ "",
+ vec![
+ RawQuery::new("peanut").must(),
+ RawQuery::new("peanut butter").phrase().must_not(),
+ RawQuery::new("sugar").must_not(),
+ RawQuery::new("roast")
+ ]
+ ))
+ );
}
- fn check_match_all(match_all: bool, wanted: Occur) -> Result<()> {
- let parser = QueryParser::new(
- Field::from_field_id(0),
- TokenizerManager::default().get("en_stem").unwrap(),
- match_all,
+ #[test]
+ fn check_field_behavior() {
+ let input = "title:banana ingredient:sugar";
+
+ // No field support: fields end up in the term
+ assert_eq!(
+ parse_query(input, &false),
+ Ok((
+ "",
+ vec![
+ RawQuery::new("title:banana"),
+ RawQuery::new("ingredient:sugar"),
+ ]
+ ))
);
- let parsed = parser.parse("two terms")?.unwrap();
+        // Any field support: field names are not validated at all
+ assert_eq!(
+ parse_query(input, &true),
+ Ok((
+ "",
+ vec![
+ RawQuery::new("banana").with_field("title"),
+ RawQuery::new("sugar").with_field("ingredient"),
+ ]
+ ))
+ );
- let bq = parsed
- .as_any()
- .downcast_ref::<BooleanQuery>()
- .expect("Must be a boolean query");
+ // Strict field support: known fields are identified, unknown
+ // ones are part of the term
+ assert_eq!(
+ parse_query(input, &vec!["ingredient"]),
+ Ok((
+ "",
+ vec![
+ RawQuery::new("title:banana"),
+ RawQuery::new("sugar").with_field("ingredient"),
+ ]
+ ))
+ );
+ }
- let clauses = bq.clauses();
+ #[test]
+ fn garbage_handling() {
+ assert_eq!(
+ parse_query("- -field: -\"\" body:\"\"", &true),
+ Ok((
+ "",
+ vec![
+ RawQuery::new("-"),
+ RawQuery::new("field:").must_not(),
+ RawQuery::new("\"\"").must_not(),
+ RawQuery::new("\"\"").with_field("body"),
+ ]
+ ))
+ );
+ }
- assert_eq!(2, clauses.len());
+ #[test]
+ fn parse_term_with_field() {
+ assert_eq!(
+ parse_query("title:potato:queen +instructions:mash -body:\"how to fail\" ingredient:\"golden peeler\"", &true),
+ Ok((
+ "",
+ vec![
+ RawQuery::new("potato:queen").with_field("title"),
+ RawQuery::new("mash").with_field("instructions").must(),
+ RawQuery::new("how to fail").with_field("body").must_not().phrase(),
+ RawQuery::new("golden peeler").with_field("ingredient").phrase()
+ ]
+ ))
+ );
+ }
- for (occur, _query) in clauses {
- assert_eq!(wanted, *occur);
+ use quickcheck::QuickCheck;
+
+ #[test]
+ fn can_handle_arbitrary_input() {
+ fn prop(input: String) -> bool {
+ parse_query(input.as_str(), &false).is_ok()
+ && parse_query(input.as_str(), &true).is_ok()
}
- Ok(())
- }
-
- #[test]
- fn queries_are_joined_according_to_match_all() -> Result<()> {
- check_match_all(true, Occur::Must)?;
- check_match_all(false, Occur::Should)
- }
-
- #[test]
- fn cannot_assemble_phrase_when_allow_phrase_is_false() {
- assert!(test_parser().assemble_query("hello world", false).is_err());
- }
-
- #[test]
- fn empty_query_results_in_none() {
- assert!(test_parser().parse("").unwrap().is_none());
- }
-
- #[test]
- fn tokenizer_may_make_query_empty() {
- // The test parses uses en_stem
- let parser = test_parser();
- // A raw tokenizer would yield Term<'> here
- assert!(parser.parse("'").unwrap().is_none());
- // And here would be a BooleanQuery with each term
- assert!(parser.parse("' < !").unwrap().is_none());
+ QuickCheck::new().quickcheck(prop as fn(String) -> bool);
}
}
Created tique/src/dismax.rs
+use tantivy::{
+ self,
+ query::{EmptyScorer, Explanation, Query, Scorer, Weight},
+ DocId, DocSet, Result, Score, Searcher, SegmentReader, SkipResult, TantivyError,
+};
+
+/// A Maximum Disjunction query, as popularized by Lucene/Solr
+///
+/// A DisMax query behaves like the union of its sub-queries, with each
+/// matching document scored as the best score among the sub-queries
+/// plus a configurable increment based on additional matches.
+///
+/// The final score formula is `score = max + (sum - max) * tiebreaker`,
+/// so with a tiebreaker of `0.0` you get only the maximum score, and if
+/// you turn it up to `1.0` the score ends up being the sum of all
+/// scores, just like a plain "should" BooleanQuery would yield.
+///
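+/// As a worked instance of the formula: a document matching two
+/// disjuncts with scores `0.8` and `0.5` under `tiebreaker = 0.1`
+/// ends up with `0.8 + (1.3 - 0.8) * 0.1 = 0.85`.
+///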
+#[derive(Debug)]
+pub struct DisMaxQuery {
+ disjuncts: Vec<Box<dyn Query>>,
+ tiebreaker: f32,
+}
+
+impl DisMaxQuery {
+ /// Create a union-like query that picks the best score instead of the sum
+ ///
+ /// Panics if tiebreaker is not within the `[0,1]` range
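+    ///
+    /// A minimal construction sketch (`title_query` and `plot_query`
+    /// here stand in for any already-built boxed queries):
+    ///
+    /// ```ignore
+    /// let dismax = DisMaxQuery::new(vec![title_query, plot_query], 0.1);
+    /// ```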
+ pub fn new(disjuncts: Vec<Box<dyn Query>>, tiebreaker: f32) -> Self {
+ assert!((0.0..=1.0).contains(&tiebreaker));
+ Self {
+ disjuncts,
+ tiebreaker,
+ }
+ }
+}
+
+impl Clone for DisMaxQuery {
+ fn clone(&self) -> Self {
+ Self {
+ disjuncts: self.disjuncts.iter().map(|q| q.box_clone()).collect(),
+ tiebreaker: self.tiebreaker,
+ }
+ }
+}
+
+impl Query for DisMaxQuery {
+ fn weight(&self, searcher: &Searcher, scoring_enabled: bool) -> Result<Box<dyn Weight>> {
+ Ok(Box::new(DisMaxWeight::new(
+ self.disjuncts
+ .iter()
+ .map(|d| d.weight(searcher, scoring_enabled))
+ .collect::<Result<Vec<_>>>()?,
+ self.tiebreaker,
+ )))
+ }
+}
+
+struct DisMaxWeight {
+ weights: Vec<Box<dyn Weight>>,
+ tiebreaker: f32,
+}
+
+impl DisMaxWeight {
+ fn new(weights: Vec<Box<dyn Weight>>, tiebreaker: f32) -> Self {
+ Self {
+ weights,
+ tiebreaker,
+ }
+ }
+}
+
+impl Weight for DisMaxWeight {
+ fn scorer(&self, reader: &SegmentReader, boost: f32) -> Result<Box<dyn Scorer>> {
+ match self.weights.len() {
+ 0 => Ok(Box::new(EmptyScorer)),
+ 1 => self.weights.get(0).unwrap().scorer(reader, boost),
+ _ => Ok(Box::new(DisMaxScorer::new(
+ self.weights
+ .iter()
+ .map(|w| w.scorer(reader, boost))
+ .collect::<Result<Vec<_>>>()?,
+ self.tiebreaker,
+ ))),
+ }
+ }
+
+ fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<Explanation> {
+ let mut scorer = self.scorer(reader, 1.0)?;
+
+ if scorer.skip_next(doc) != SkipResult::Reached {
+ return Err(TantivyError::InvalidArgument("Not a match".to_owned()));
+ }
+
+ let mut explanation = Explanation::new(
+ format!(
+ "DisMaxQuery. Score = max + (sum - max) * {}",
+ self.tiebreaker
+ ),
+ scorer.score(),
+ );
+
+ for weight in &self.weights {
+ if let Ok(sub_explanation) = weight.explain(reader, doc) {
+ explanation.add_detail(sub_explanation);
+ }
+ }
+
+ Ok(explanation)
+ }
+}
+
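+// Unions the sub-scorers: advance() moves to the smallest doc id any
+// remaining scorer points at, dropping exhausted scorers as it goes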
+struct DisMaxScorer {
+ scorers: Vec<Box<dyn Scorer>>,
+ current: Option<DocId>,
+ tiebreaker: f32,
+}
+
+impl DisMaxScorer {
+ fn new(scorers: Vec<Box<dyn Scorer>>, tiebreaker: f32) -> Self {
+ Self {
+ scorers,
+ tiebreaker,
+ current: None,
+ }
+ }
+}
+
+impl Scorer for DisMaxScorer {
+ fn score(&mut self) -> Score {
+ let mut max = 0.0;
+ let mut sum = 0.0;
+
+ debug_assert!(self.current.is_some());
+ for scorer in self.scorers.iter_mut() {
+ if self.current.map(|d| scorer.doc() == d).unwrap_or(false) {
+ let score = scorer.score();
+ sum += score;
+
+ if score > max {
+ max = score;
+ }
+ }
+ }
+
+ max + (sum - max) * self.tiebreaker
+ }
+}
+
+impl DocSet for DisMaxScorer {
+ fn advance(&mut self) -> bool {
+ let mut next_target = None;
+ let mut to_remove = Vec::new();
+
+ for (idx, scorer) in self.scorers.iter_mut().enumerate() {
+ // Advance every scorer that's on target or behind
+ if self.current.map(|d| d >= scorer.doc()).unwrap_or(true) && !scorer.advance() {
+ to_remove.push(idx);
+ continue;
+ }
+
+ let doc = scorer.doc();
+ if next_target.map(|next| doc < next).unwrap_or(true) {
+ next_target.replace(doc);
+ }
+ }
+
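+        // Indices were collected in ascending order, so popping
+        // removes from the back and keeps earlier indices valid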
+ while let Some(idx) = to_remove.pop() {
+ self.scorers.remove(idx);
+ }
+
+ if let Some(target) = next_target {
+ self.current.replace(target);
+ true
+ } else {
+ false
+ }
+ }
+
+ fn doc(&self) -> tantivy::DocId {
+ debug_assert!(self.current.is_some());
+ self.current.unwrap_or(0)
+ }
+
+ fn size_hint(&self) -> u32 {
+ 0
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use std::{num::Wrapping, ops::Range};
+
+ use tantivy::{
+ doc,
+ query::TermQuery,
+ schema::{IndexRecordOption, SchemaBuilder, TEXT},
+ DocAddress, Index, Term,
+ };
+
+    // XXX Would rather use ConstScorer::from(VecDocSet::from(...)),
+    // but I can't seem to import tantivy::query::VecDocSet here
+ struct VecScorer {
+ doc_ids: Vec<DocId>,
+ cursor: Wrapping<usize>,
+ }
+
+ impl Scorer for VecScorer {
+ fn score(&mut self) -> Score {
+ 1.0
+ }
+ }
+
+ impl DocSet for VecScorer {
+ fn advance(&mut self) -> bool {
+ self.cursor += Wrapping(1);
+ self.doc_ids.len() > self.cursor.0
+ }
+
+ fn doc(&self) -> DocId {
+ self.doc_ids[self.cursor.0]
+ }
+
+ fn size_hint(&self) -> u32 {
+ self.doc_ids.len() as u32
+ }
+ }
+
+ fn test_scorer(range: Range<DocId>) -> Box<dyn Scorer> {
+ Box::new(VecScorer {
+ doc_ids: range.collect(),
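+            // start "before" index 0 so the first advance() wraps to 0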
+ cursor: Wrapping(usize::max_value()),
+ })
+ }
+
+ #[test]
+ fn scorer_advances_as_union() {
+ let scorers = vec![
+ test_scorer(0..10),
+ test_scorer(5..20),
+ test_scorer(9..30),
+ test_scorer(42..43),
+ test_scorer(13..13), // empty docset
+ ];
+
+ let mut dismax = DisMaxScorer::new(scorers, 0.0);
+
+ for i in 0..30 {
+ assert!(dismax.advance(), "failed advance at i={}", i);
+ assert_eq!(i, dismax.doc());
+ }
+
+ assert!(dismax.advance());
+ assert_eq!(42, dismax.doc());
+ assert!(!dismax.advance(), "scorer should have ended by now");
+ }
+
+ #[test]
+ #[allow(clippy::float_cmp)]
+ fn tiebreaker() {
+ let scorers = vec![test_scorer(4..5), test_scorer(4..6), test_scorer(4..7)];
+
+ // So now the score is the sum of scores for
+ // every matching scorer (VecScorer always yields 1)
+ let mut dismax = DisMaxScorer::new(scorers, 1.0);
+
+ assert!(dismax.advance());
+ assert_eq!(3.0, dismax.score());
+ assert!(dismax.advance());
+ assert_eq!(2.0, dismax.score());
+ assert!(dismax.advance());
+ assert_eq!(1.0, dismax.score());
+ assert!(!dismax.advance(), "scorer should have ended by now");
+
+ let scorers = vec![test_scorer(7..8), test_scorer(7..8)];
+
+        // With a tiebreaker of 0, only the maximum
+        // score is used
+ let mut dismax = DisMaxScorer::new(scorers, 0.0);
+ assert!(dismax.advance());
+ // So now, even though doc=7 occurs twice, the score is 1
+ assert_eq!(1.0, dismax.score());
+ assert!(!dismax.advance(), "scorer should have ended by now");
+ }
+
+ #[test]
+    fn explanation() -> Result<()> {
+ let mut builder = SchemaBuilder::new();
+ let field = builder.add_text_field("field", TEXT);
+ let index = Index::create_in_ram(builder.build());
+ let mut writer = index.writer_with_num_threads(1, 3_000_000)?;
+
+ writer.add_document(doc!(field => "foo"));
+ writer.add_document(doc!(field => "bar"));
+ writer.add_document(doc!(field => "foo bar"));
+ writer.add_document(doc!(field => "baz"));
+ writer.commit()?;
+
+ let foo_query = TermQuery::new(
+ Term::from_field_text(field, "foo"),
+ IndexRecordOption::Basic,
+ );
+
+ let bar_query = TermQuery::new(
+ Term::from_field_text(field, "bar"),
+ IndexRecordOption::Basic,
+ );
+
+ let reader = index.reader()?;
+ let searcher = reader.searcher();
+
+ let dismax = DisMaxQuery::new(vec![Box::new(foo_query), Box::new(bar_query)], 0.0);
+
+ let baz_doc = DocAddress(0, 3);
+ assert!(
+ dismax.explain(&searcher, baz_doc).is_err(),
+ "Shouldn't be able to explain a non-matching doc"
+ );
+
+ // Ensure every other doc can be explained
+ for doc_id in 0..3 {
+ let explanation = dismax.explain(&searcher, DocAddress(0, doc_id))?;
+ assert!(explanation.to_pretty_json().contains("DisMaxQuery"));
+ }
+
+ Ok(())
+ }
+}