Cleanly report code generation errors
- Id
- a59ac0a5ef59a40ac6b4bf826747538431ccaea2
- Author
- Caio
- Commit time
- 2020-02-08T19:45:06+01:00
Modified cantine_derive/internal/src/lib.rs
extern crate proc_macro;
use proc_macro::TokenStream;
-use proc_macro2::TokenStream as TokenStream2;
+use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::{format_ident, quote, quote_spanned};
use syn::{
parse_macro_input, spanned::Spanned, Data, DeriveInput, Field, Fields, GenericArgument,
#[proc_macro_derive(Filterable)]
pub fn derive_filter_and_agg(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
- let filter_query = make_filter_query(&input);
- TokenStream::from(quote! {
- #filter_query
- })
+ TokenStream::from(
+ parse_public_fields(&input).map_or_else(render_error, |fields| {
+ make_filter_query(&input.ident, &fields)
+ }),
+ )
}
#[proc_macro_derive(Aggregable)]
pub fn derive_agg(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
- let agg_query = make_agg_query(&input);
- let agg_result = make_agg_result(&input);
-
- TokenStream::from(quote! {
- #agg_query
- #agg_result
- })
+ TokenStream::from(
+ parse_public_fields(&input).map_or_else(render_error, |fields| {
+ let agg_query = make_agg_query(&input.ident, &fields);
+ let agg_result = make_agg_result(&input.ident, &fields);
+ quote! {
+ #agg_query
+ #agg_result
+ }
+ }),
+ )
}
-fn make_filter_query(input: &DeriveInput) -> TokenStream2 {
- let feat = &input.ident;
- let name = format_ident!("FilterableFilterQuery{}", &input.ident);
+fn parse_public_fields(input: &DeriveInput) -> Result<Vec<FieldInfo<'_>>, Error> {
+ let fields = get_public_fields(input)?;
- let fields: Vec<_> = get_public_struct_fields(&input).cloned().collect();
+ if fields.is_empty() {
+ Err(Error::BadInput)
+ } else {
+ let mut infos = Vec::with_capacity(fields.len());
+
+ for field in fields.into_iter() {
+ infos.push(FieldInfo::new(field)?);
+ }
+
+ Ok(infos)
+ }
+}
+
+fn get_public_fields(input: &DeriveInput) -> Result<Vec<&Field>, Error> {
+ match input.data {
+ Data::Struct(ref data) => match data.fields {
+ Fields::Named(ref fields) => Ok(fields
+ .named
+ .iter()
+ .filter(|field| match &field.vis {
+ Visibility::Public(_) => true,
+ _ => false,
+ })
+ .collect()),
+ _ => Err(Error::BadInput),
+ },
+ _ => Err(Error::BadInput),
+ }
+}
+
+struct FieldInfo<'a> {
+ span: Span,
+
+ ident: &'a Ident,
+ ty: &'a Type,
+ is_optional: bool,
+
+ schema: FieldType,
+ is_largest: bool,
+}
+
+impl<'a> FieldInfo<'a> {
+ fn new(field: &'a Field) -> Result<Self, Error> {
+ let span = field.span();
+ let ident = field.ident.as_ref().ok_or(Error::BadField(span))?;
+
+ let optional_type = extract_type_if_option(&field.ty);
+ let is_optional = optional_type.is_some();
+ let ty = optional_type.unwrap_or(&field.ty);
+
+ let (schema, is_largest) = get_field_type(&ty).ok_or(Error::BadField(span))?;
+ Ok(Self {
+ span,
+ ident,
+ ty,
+ is_optional,
+ schema,
+ is_largest,
+ })
+ }
+
+ fn span(&self) -> Span {
+ self.span
+ }
+}
+
+fn make_filter_query(feat: &Ident, fields: &[FieldInfo]) -> TokenStream2 {
+ let name = format_ident!("FilterableFilterQuery{}", &feat);
let query_fields = fields.iter().map(|field| {
let name = &field.ident;
- let ty = extract_type_if_option(&field.ty).unwrap_or(&field.ty);
+ let ty = &field.ty;
- quote_spanned! { field.span()=>
+ quote_spanned! { field.span() =>
#[serde(skip_serializing_if = "Option::is_none")]
pub #name: Option<std::ops::Range<#ty>>
}
});
- let index_name = format_ident!("FilterableFilterFields{}", &input.ident);
+ let index_name = format_ident!("FilterableFilterFields{}", &feat);
let index_fields = fields.iter().map(|field| {
- let name = &field.ident;
+ let name = field.ident;
quote_spanned! { field.span()=>
pub #name: tantivy::schema::Field
}
});
let from_decls = fields.iter().map(|field| {
- let name = field.ident.as_ref().unwrap();
+ let name = field.ident;
let field_name = format_ident!("Filterable_field_{}", &name);
let quoted = format!("\"{}\"", field_name);
- let ty = extract_type_if_option(&field.ty).unwrap_or(&field.ty);
- let field_type = get_field_type(&ty);
- let method = match field_type {
+ let method = match field.schema {
FieldType::UNSIGNED => quote!(add_u64_field),
FieldType::SIGNED => quote!(add_i64_field),
FieldType::FLOAT => quote!(add_f64_field),
});
let try_from_decls = fields.iter().map(|field| {
- if let Some(name) = &field.ident {
- let field_name = format_ident!("Filterable_field_{}", &name);
- let err_msg = format!("Missing field for {} ({})", name, field_name);
- let quoted = format!("\"{}\"", field_name);
- quote_spanned! { field.span()=>
- #name: schema.get_field(#quoted).ok_or_else(
- || tantivy::TantivyError::SchemaError(#err_msg.to_string()))?
- }
- } else {
- unreachable!();
+ let name = field.ident;
+ let field_name = format_ident!("Filterable_field_{}", &name);
+ let err_msg = format!("Missing field for {} ({})", name, field_name);
+ let quoted = format!("\"{}\"", field_name);
+ quote_spanned! { field.span()=>
+ #name: schema.get_field(#quoted).ok_or_else(
+ || tantivy::TantivyError::SchemaError(#err_msg.to_string()))?
}
});
let interpret_code = fields.iter().map(|field| {
- let name = &field.ident;
- let ty = extract_type_if_option(&field.ty).unwrap_or(&field.ty);
- let is_largest = is_largest_type(&ty);
- let field_type = get_field_type(&ty);
+ let name = field.ident;
- let (from_code, query_code) = match field_type {
+ let (from_code, query_code) = match field.schema {
FieldType::UNSIGNED => (
quote!(u64::from),
quote!(tantivy::query::RangeQuery::new_u64),
),
};
- let range_code = if is_largest {
+ let range_code = if field.is_largest {
quote! {
let range = rr.clone();
}
});
let add_to_doc_code = fields.iter().map(|field| {
- let name = &field.ident;
+ let name = field.ident;
- let opt_type = extract_type_if_option(&field.ty);
- let is_optional = opt_type.is_some();
-
- let ty = opt_type.unwrap_or(&field.ty);
- let is_largest = is_largest_type(&ty);
-
- let field_type = get_field_type(&ty);
-
- let convert_code = if is_largest {
+ let convert_code = if field.is_largest {
quote_spanned! { field.span()=>
let value = value;
}
} else {
- match field_type {
+ match field.schema {
FieldType::UNSIGNED => quote_spanned! { field.span()=>
let value = u64::from(value);
},
}
};
- let add_code = match field_type {
+ let add_code = match field.schema {
FieldType::UNSIGNED => quote!(doc.add_u64(self.#name, value);),
FieldType::SIGNED => quote!(doc.add_i64(self.#name, value);),
FieldType::FLOAT => quote!(doc.add_f64(self.#name, value);),
};
- if is_optional {
+ if field.is_optional {
quote_spanned! { field.span()=>
if let Some(value) = feat.#name {
#convert_code
}
}
-fn make_agg_query(input: &DeriveInput) -> TokenStream2 {
- let name = format_ident!("AggregableAggregationQuery{}", &input.ident);
+fn make_agg_query(feat: &Ident, fields: &[FieldInfo]) -> TokenStream2 {
+ let name = format_ident!("AggregableAggregationQuery{}", &feat);
- let fields = get_public_struct_fields(&input).map(|field| {
+ let query_fields = fields.iter().map(|field| {
let name = &field.ident;
- let ty = extract_type_if_option(&field.ty).unwrap_or(&field.ty);
+ let ty = &field.ty;
quote_spanned! { field.span()=>
#[serde(default = "Vec::new")]
pub #name: Vec<std::ops::Range<#ty>>
}
});
- let full_range = get_public_struct_fields(&input).map(|field| {
+ let full_range = fields.iter().map(|field| {
let name = &field.ident;
- let ty = extract_type_if_option(&field.ty).unwrap_or(&field.ty);
+ let ty = &field.ty;
quote_spanned! { field.span()=>
#name: vec![std::#ty::MIN..std::#ty::MAX]
}
#[derive(serde::Serialize, serde::Deserialize, Default, Debug, Clone, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct #name {
- #(#fields),*
+ #(#query_fields),*
}
impl #name {
}
}
-fn make_agg_result(input: &DeriveInput) -> TokenStream2 {
- let feature = &input.ident;
- let name = format_ident!("AggregableAggregationResult{}", &input.ident);
+fn make_agg_result(feature: &Ident, fields: &[FieldInfo]) -> TokenStream2 {
+ let name = format_ident!("AggregableAggregationResult{}", &feature);
- let fields = get_public_struct_fields(&input).map(|field| {
+ let agg_fields = fields.iter().map(|field| {
let name = &field.ident;
- let ty = extract_type_if_option(&field.ty).unwrap_or(&field.ty);
+ let ty = &field.ty;
quote_spanned! { field.span()=>
#[serde(skip_serializing_if = "Vec::is_empty")]
}
});
- let merge_code = get_public_struct_fields(&input).map(|field| {
+ let merge_code = fields.iter().map(|field| {
let name = &field.ident;
quote_spanned! { field.span()=>
for (idx, stats) in self.#name.iter_mut().enumerate() {
}
});
- let agg_query = format_ident!("AggregableAggregationQuery{}", &input.ident);
- let convert_code = get_public_struct_fields(&input).map(|field| {
+ let agg_query = format_ident!("AggregableAggregationQuery{}", &feature);
+ let convert_code = fields.iter().map(|field| {
let name = &field.ident;
quote_spanned! { field.span()=>
#name:
}
});
- let collect_code = get_public_struct_fields(&input).map(|field| {
+ let collect_code = fields.iter().map(|field| {
let name = &field.ident;
- if let Some(_type) = extract_type_if_option(&field.ty) {
+ if field.is_optional {
quote_spanned! { field.span()=>
if let Some(feat) = feature.#name {
for (idx, range) in query.#name.iter().enumerate() {
quote! {
#[derive(serde::Serialize, Default, Debug, Clone)]
pub struct #name {
- #(#fields),*
+ #(#agg_fields),*
}
impl cantine_derive::Aggregable for #feature {
}
}
-fn get_public_struct_fields(input: &DeriveInput) -> impl Iterator<Item = &Field> {
- match input.data {
- Data::Struct(ref data) => match data.fields {
- Fields::Named(ref fields) => fields.named.iter().filter(|field| match &field.vis {
- Visibility::Public(_) => true,
- _ => false,
- }),
- _ => unimplemented!(),
- },
- _ => unimplemented!(),
+fn extract_type_if_option(ty: &Type) -> Option<&Type> {
+ match ty {
+ Type::Path(tp) if tp.path.segments.first()?.ident == "Option" => {
+ match tp.path.segments.first()?.arguments {
+ PathArguments::AngleBracketed(ref params) => match params.args.first()? {
+ GenericArgument::Type(ty) => Some(ty),
+ _ => None,
+ },
+ _ => None,
+ }
+ }
+ _ => None,
}
}
FLOAT,
}
-const SUPPORTED_UNSIGNED: [&str; 4] = ["u8", "u16", "u32", "u64"];
-const SUPPORTED_SIGNED: [&str; 4] = ["i8", "i16", "i32", "i64"];
-const SUPPORTED_FLOAT: [&str; 2] = ["f32", "f64"];
+fn get_field_type(ty: &Type) -> Option<(FieldType, bool)> {
+ match ty {
+ Type::Path(tp) if tp.path.segments.len() == 1 => {
+ match tp.path.segments.first()?.ident.to_string().as_str() {
+ "u64" => Some((FieldType::UNSIGNED, true)),
+ "u8" | "u16" | "u32" => Some((FieldType::UNSIGNED, false)),
-const LARGEST_TYPE: [&str; 3] = ["u64", "i64", "f64"];
+ "i64" => Some((FieldType::SIGNED, true)),
+ "i8" | "i16" | "i32" => Some((FieldType::SIGNED, false)),
-fn is_largest_type(ty: &Type) -> bool {
- if let Type::Path(tp) = ty {
- if tp.path.segments.len() == 1 {
- let ident = &tp.path.segments.first().unwrap().ident;
-
- for name in LARGEST_TYPE.iter() {
- if ident == name {
- return true;
- }
+ "f64" => Some((FieldType::FLOAT, true)),
+ "f32" => Some((FieldType::FLOAT, false)),
+ _ => None,
}
}
+ _ => None,
}
- false
}
-fn get_field_type(ty: &Type) -> FieldType {
- if let Type::Path(tp) = ty {
- if tp.path.segments.len() == 1 {
- let ident = &tp.path.segments.first().unwrap().ident;
-
- for name in SUPPORTED_SIGNED.iter() {
- if ident == name {
- return FieldType::SIGNED;
- }
- }
-
- for name in SUPPORTED_UNSIGNED.iter() {
- if ident == name {
- return FieldType::UNSIGNED;
- }
- }
-
- for name in SUPPORTED_FLOAT.iter() {
- if ident == name {
- return FieldType::FLOAT;
- }
- }
- }
- }
- unimplemented!()
+enum Error {
+ BadField(Span),
+ BadInput,
}
-fn extract_type_if_option(ty: &Type) -> Option<&Type> {
- if let Type::Path(tp) = ty {
- if tp.path.segments.len() == 1 && tp.path.segments.first().unwrap().ident == "Option" {
- if let Some(type_params) = tp.path.segments.first() {
- if let PathArguments::AngleBracketed(ref params) = type_params.arguments {
- let generic_arg = params.args.first().unwrap();
- if let GenericArgument::Type(ty) = generic_arg {
- return Some(ty);
- }
- }
+fn render_error(err: Error) -> TokenStream2 {
+ match err {
+ Error::BadField(span) => {
+ quote_spanned! { span =>
+ compile_error!("Unsupported field");
}
}
+ Error::BadInput => panic!("Only structs with public named fields are supported"),
}
- None
}