caio.co/de/cantine

Add types back to cursors

This patch starts adding the type information to the cursors. This
information is "lost" by containing everything as (u64, u64), so we
add it back by making cursor types enums.

The price I'm paying is an added byte for the type tag (mapping to
an enum member) which leads to 2 extra characters in the base64
SearchCursor representation.
Id
3049519e68ae7e48554cb3e6fe89df3608cd96b9
Author
Caio
Commit time
2020-01-10T13:56:27+01:00

Modified cantine/src/index.rs

@@ -7,7 +7,7
fastfield::FastFieldReader,
query::Query,
schema::{Field, Schema, SchemaBuilder, Value, FAST, STORED, TEXT},
- DocId, Document, Result, Searcher, SegmentLocalId, SegmentReader, TantivyError,
+ DocId, Document, Result, Score, Searcher, SegmentLocalId, SegmentReader, TantivyError,
};

use crate::model::{
@@ -75,58 +75,6
}

Ok(items)
- }
-
- fn topk_u64<S: 'static + ScorerForSegment<u64>>(
- &self,
- searcher: &Searcher,
- query: &dyn Query,
- limit: usize,
- after: After,
- scorer: S,
- ) -> Result<(usize, Vec<RecipeId>, Option<After>)> {
- let condition = Paginator::new_u64(self.id, after);
- let top_collector = CustomScoreTopCollector::new(limit, condition, scorer);
-
- let result = searcher.search(query, &top_collector)?;
- let items = self.addresses_to_ids(&searcher, &result.items)?;
-
- let num_items = items.len();
- let cursor = if result.visited.saturating_sub(num_items) > 0 {
- let last_score = result.items[num_items - 1].score;
- let last_id = items[num_items - 1];
- Some((last_score, last_id).into())
- } else {
- None
- };
-
- Ok((result.total, items, cursor))
- }
-
- fn topk_f64<S: 'static + ScorerForSegment<f64>>(
- &self,
- searcher: &Searcher,
- query: &dyn Query,
- limit: usize,
- after: After,
- scorer: S,
- ) -> Result<(usize, Vec<RecipeId>, Option<After>)> {
- let condition = Paginator::new_f64(self.id, after);
- let top_collector = CustomScoreTopCollector::new(limit, condition, scorer);
-
- let result = searcher.search(query, &top_collector)?;
- let items = self.addresses_to_ids(&searcher, &result.items)?;
-
- let num_items = items.len();
- let cursor = if result.visited.saturating_sub(num_items) > 0 {
- let last_score = result.items[num_items - 1].score;
- let last_id = items[num_items - 1];
- Some((last_score, last_id).into())
- } else {
- None
- };
-
- Ok((result.total, items, cursor))
}

pub fn search(
@@ -231,6 +179,44
}
}

+macro_rules! impl_typed_topk_fn {
+ ($name: ident, $type: ty, $paginator: ident) => {
+ impl RecipeIndex {
+ fn $name<S>(
+ &self,
+ searcher: &Searcher,
+ query: &dyn Query,
+ limit: usize,
+ after: After,
+ scorer: S,
+ ) -> Result<(usize, Vec<RecipeId>, Option<After>)>
+ where
+ S: 'static + ScorerForSegment<$type>,
+ {
+ let condition = Paginator::$paginator(self.id, after);
+ let top_collector = CustomScoreTopCollector::new(limit, condition, scorer);
+
+ let result = searcher.search(query, &top_collector)?;
+ let items = self.addresses_to_ids(&searcher, &result.items)?;
+
+ let num_items = items.len();
+ let cursor = if result.visited.saturating_sub(num_items) > 0 {
+ let last_score = result.items[num_items - 1].score;
+ let last_id = items[num_items - 1];
+ Some((last_score, last_id).into())
+ } else {
+ None
+ };
+
+ Ok((result.total, items, cursor))
+ }
+ }
+ };
+}
+
+impl_typed_topk_fn!(topk_u64, u64, new_u64);
+impl_typed_topk_fn!(topk_f64, f64, new_f64);
+
impl From<&mut SchemaBuilder> for RecipeIndex {
fn from(builder: &mut SchemaBuilder) -> Self {
RecipeIndex {
@@ -267,52 +253,29
}
}

-#[derive(Serialize, Deserialize, Debug, Default, Clone)]
-pub struct After(u64, RecipeId);
-
-impl After {
- pub const START: Self = Self(INVALID_RECIPE_ID, 0);
-
- pub fn new(score: u64, recipe_id: RecipeId) -> Self {
- Self(score, recipe_id)
- }
-
- pub fn recipe_id(&self) -> RecipeId {
- self.1
- }
-
- pub fn score(&self) -> u64 {
- self.0
- }
-
- fn score_f32(&self) -> f32 {
- f32::from_bits(self.0 as u32)
- }
-
- fn score_f64(&self) -> f64 {
- f64::from_bits(self.0)
- }
-
- fn is_start(&self) -> bool {
- self.0 == INVALID_RECIPE_ID
- }
+#[derive(Serialize, Deserialize, Debug, Clone)]
+pub enum After {
+ Start,
+ Relevance(Score, RecipeId),
+ F64Field(f64, RecipeId),
+ U64Field(u64, RecipeId),
}

-impl From<(f32, RecipeId)> for After {
- fn from(src: (f32, RecipeId)) -> Self {
- Self(src.0.to_bits() as u64, src.1)
+impl From<(Score, RecipeId)> for After {
+ fn from(src: (Score, RecipeId)) -> Self {
+ After::Relevance(src.0, src.1)
}
}

impl From<(f64, RecipeId)> for After {
fn from(src: (f64, RecipeId)) -> Self {
- Self(src.0.to_bits(), src.1)
+ After::F64Field(src.0, src.1)
}
}

impl From<(u64, RecipeId)> for After {
fn from(src: (u64, RecipeId)) -> Self {
- Self(src.0, src.1)
+ After::U64Field(src.0, src.1)
}
}

@@ -346,29 +309,31

impl Paginator<u64> {
pub fn new_u64(field: Field, after: After) -> Self {
- Paginator(field, after.is_start(), after.recipe_id(), after.score())
+ match after {
+ After::Start => Paginator(field, true, INVALID_RECIPE_ID, 0),
+ After::U64Field(score, id) => Paginator(field, false, id, score),
+ rest => panic!("Can't handle {:?}", rest),
+ }
}
}

impl Paginator<f64> {
pub fn new_f64(field: Field, after: After) -> Self {
- Paginator(
- field,
- after.is_start(),
- after.recipe_id(),
- after.score_f64(),
- )
+ match after {
+ After::Start => Paginator(field, true, INVALID_RECIPE_ID, 0.0),
+ After::F64Field(score, id) => Paginator(field, false, id, score),
+ rest => panic!("Can't handle {:?}", rest),
+ }
}
}

impl Paginator<f32> {
pub fn new(field: Field, after: After) -> Self {
- Paginator(
- field,
- after.is_start(),
- after.recipe_id(),
- after.score_f32(),
- )
+ match after {
+ After::Start => Paginator(field, true, INVALID_RECIPE_ID, 0.0),
+ After::Relevance(score, id) => Paginator(field, false, id, score),
+ rest => panic!("Can't handle {:?}", rest),
+ }
}
}

Modified cantine/src/main.rs

@@ -60,16 +60,26
Ok(HttpResponse::Ok().json(info.get_ref()))
}

+fn cursor_to_after(database: &RecipeDatabase, cursor: &SearchCursor) -> Option<After> {
+ database
+ .id_for_uuid(&Uuid::from_bytes(*cursor.uuid()))
+ .map(|id| match &cursor {
+ SearchCursor::Relevance(score, _) => After::Relevance(*score, *id),
+ SearchCursor::U64Field(score, _) => After::U64Field(*score, *id),
+ SearchCursor::F64Field(score, _) => After::F64Field(*score, *id),
+ })
+}
+
pub async fn search(
query: web::Json<SearchQuery>,
state: web::Data<Arc<SearchState>>,
database: web::Data<RecipeDatabase>,
) -> ActixResult<HttpResponse> {
let after = match &query.after {
- None => After::START,
+ None => After::Start,
Some(cursor) => {
- if let Some(recipe_id) = database.id_for_uuid(&Uuid::from_bytes(cursor.1)) {
- After::new(cursor.0, *recipe_id)
+ if let Some(after) = cursor_to_after(&database, cursor) {
+ after
} else {
return Ok(HttpResponse::new(StatusCode::BAD_REQUEST));
}
@@ -88,9 +98,15
items.push(RecipeCard::from(recipe));
}

- let next = after.map(|cursor| {
- let last = &items[num_results - 1];
- SearchCursor::new(cursor.score(), &last.uuid)
+ let next = after.map(|after| {
+ let last_uuid = &items[num_results - 1].uuid;
+
+ match after {
+ After::Relevance(score, _) => SearchCursor::Relevance(score, *last_uuid.as_bytes()),
+ After::U64Field(score, _) => SearchCursor::U64Field(score, *last_uuid.as_bytes()),
+ After::F64Field(score, _) => SearchCursor::F64Field(score, *last_uuid.as_bytes()),
+ _ => unreachable!(),
+ }
});

Ok(HttpResponse::Ok().json(SearchResult {

Modified cantine/src/model.rs

@@ -1,10 +1,11
-use std::{convert::TryInto, mem::size_of};
+use std::convert::TryInto;

use base64::{self, URL_SAFE_NO_PAD};
use serde::{
de::{Deserializer, Error, Visitor},
Deserialize, Serialize, Serializer,
};
+use tantivy::Score;
use uuid::{self, Uuid};

use crate::database::DatabaseRecord;
@@ -168,32 +169,64
pub next: Option<SearchCursor>,
}

-#[derive(Debug, Default, PartialEq)]
-pub struct SearchCursor(pub u64, pub uuid::Bytes);
+#[derive(Debug, PartialEq)]
+pub enum SearchCursor {
+ F64Field(f64, uuid::Bytes),
+ U64Field(u64, uuid::Bytes),
+ Relevance(Score, uuid::Bytes),
+}

impl SearchCursor {
- pub const SIZE: usize = size_of::<SearchCursor>();
+ /// tag + score_as_bits + uuid
+ pub const SIZE: usize = 1 + 8 + 16;

- pub fn new(score_bits: u64, uuid: &Uuid) -> Self {
- Self(score_bits, *uuid.as_bytes())
+ pub fn uuid(&self) -> &uuid::Bytes {
+ match self {
+ Self::Relevance(_, uuid) => uuid,
+ Self::U64Field(_, uuid) => uuid,
+ Self::F64Field(_, uuid) => uuid,
+ }
}

pub fn from_bytes(src: &[u8; Self::SIZE]) -> Self {
- let score_bits =
- u64::from_be_bytes(src[0..8].try_into().expect("Slice has correct length"));
- Self(
- score_bits,
- src[8..].try_into().expect("Slice has correct length"),
- )
+ // tag 0 + 0-padding for f32
+ if src[0..5] == [0, 0, 0, 0, 0] {
+ let score = f32::from_be_bytes(src[5..9].try_into().unwrap());
+ Self::Relevance(score, src[9..].try_into().unwrap())
+ } else if src[0] == 1 {
+ let score = u64::from_be_bytes(src[1..9].try_into().unwrap());
+ Self::U64Field(score, src[9..].try_into().unwrap())
+ } else if src[0] == 2 {
+ let score = f64::from_be_bytes(src[1..9].try_into().unwrap());
+ Self::F64Field(score, src[9..].try_into().unwrap())
+ } else {
+ todo!("change interface to Result")
+ }
}

pub fn write_bytes(&self, buf: &mut [u8; Self::SIZE]) {
- buf[0..8].copy_from_slice(&self.0.to_be_bytes());
- buf[8..].copy_from_slice(&self.1[..]);
+ match self {
+ Self::Relevance(score, uuid) => {
+ // tag 0 + 0-padding
+ buf[0..5].copy_from_slice(&[0, 0, 0, 0, 0]);
+ buf[5..9].copy_from_slice(&score.to_be_bytes());
+ buf[9..].copy_from_slice(&uuid[..]);
+ }
+ Self::U64Field(score, uuid) => {
+ buf[0] = 1;
+ buf[1..9].copy_from_slice(&score.to_be_bytes());
+ buf[9..].copy_from_slice(&uuid[..]);
+ }
+ Self::F64Field(score, uuid) => {
+ buf[0] = 2;
+ buf[1..9].copy_from_slice(&score.to_be_bytes());
+ buf[9..].copy_from_slice(&uuid[..]);
+ }
+ }
}
}

-const ENCODED_SEARCH_CURSOR_LEN: usize = 32;
+const ENCODED_SEARCH_CURSOR_LEN: usize = 34;

impl Serialize for SearchCursor {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
@@ -254,13 +287,23

#[test]
fn search_cursor_json_round_trip() {
- for i in 0..100 {
- let cursor = SearchCursor::new(i, &Uuid::new_v4());
-
+ let roundtrip = |cursor| {
let serialized = serde_json::to_string(&cursor).unwrap();
let deserialized = serde_json::from_str(&serialized).unwrap();

assert_eq!(cursor, deserialized);
+ };
+
+ for i in 0..100 {
+ roundtrip(SearchCursor::Relevance(
+ i as f32 * 1.0f32,
+ *Uuid::new_v4().as_bytes(),
+ ));
+ roundtrip(SearchCursor::U64Field(i, *Uuid::new_v4().as_bytes()));
+ roundtrip(SearchCursor::F64Field(
+ i as f64 * 1.0f64,
+ *Uuid::new_v4().as_bytes(),
+ ));
}
}

Modified cantine/tests/index_integration.rs

@@ -55,7 +55,7
let reader = GLOBAL.index.reader()?;
let searcher = reader.searcher();

- let mut after = After::START;
+ let mut after = After::Start;
let mut seen = HashSet::with_capacity(INDEX_SIZE);

loop {
@@ -91,7 +91,7
INDEX_SIZE,
Sort::NumIngredients,
false,
- After::START,
+ After::Start,
)?;

let mut last_num_ingredients = std::u8::MAX;
@@ -116,7 +116,7
INDEX_SIZE,
Sort::ProteinContent,
false,
- After::START,
+ After::Start,
)?;

let mut last_protein = std::f32::MAX;
@@ -144,7 +144,7
INDEX_SIZE,
Sort::InstructionsLength,
true,
- After::START,
+ After::Start,
)?;

let mut last_len = 0;