diff options
Diffstat (limited to 'validator_derive/src/lib.rs')
| -rw-r--r-- | validator_derive/src/lib.rs | 118 |
1 files changed, 115 insertions, 3 deletions
diff --git a/validator_derive/src/lib.rs b/validator_derive/src/lib.rs index eaf82e1..92f3ea5 100644 --- a/validator_derive/src/lib.rs +++ b/validator_derive/src/lib.rs @@ -14,9 +14,17 @@ use validator::{Validator}; static RANGE_TYPES: [&'static str; 12] = [ - "usize", "u8", "u16", "u32", "u64", "isize", "i8", "i16", "i32", "i64", "f32", "f64" + "usize", "u8", "u16", "u32", "u64", + "isize", "i8", "i16", "i32", "i64", + "f32", "f64", ]; +#[derive(Debug)] +struct SchemaValidation { + function: String, + skip_on_field_errors: bool, +} + #[proc_macro_derive(Validate, attributes(validate))] pub fn derive_validation(input: TokenStream) -> TokenStream { @@ -107,7 +115,9 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { let fn_ident = syn::Ident::new(f.clone()); quote!( match #fn_ident(&self.#field_ident) { - ::std::option::Option::Some(s) => errors.entry(#name.to_string()).or_insert_with(|| vec![]).push(s), + ::std::option::Option::Some(s) => { + errors.entry(#name.to_string()).or_insert_with(|| vec![]).push(s) + }, ::std::option::Option::None => (), }; ) @@ -116,6 +126,25 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { } } + let struct_validation = find_struct_validation(&ast.attrs); + let struct_validation_tokens = match struct_validation { + Some(s) => { + let fn_ident = syn::Ident::new(s.function); + let skip_on_field_errors = s.skip_on_field_errors; + quote!( + if !#skip_on_field_errors || #skip_on_field_errors && errors.len() == 0 { + match #fn_ident(self) { + ::std::option::Option::Some((key, val)) => { + errors.entry(key).or_insert_with(|| vec![]).push(val) + }, + ::std::option::Option::None => (), + } + } + ) + }, + None => quote!() + }; + let ident = &ast.ident; let impl_ast = quote!( impl Validate for #ident { @@ -123,7 +152,9 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { use std::collections::HashMap; let mut errors = HashMap::new(); - #(#validations)* + #(#validations)* + + #struct_validation_tokens if errors.is_empty() { ::std::result::Result::Ok(()) @@ -137,6 +168,80 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { impl_ast } + +/// Find if a struct has some schema validation and returns the info if so +fn find_struct_validation(struct_attrs: &Vec<syn::Attribute>) -> Option<SchemaValidation> { + let error = |msg: &str| -> ! { + panic!("Invalid schema level validation: {}", msg); + }; + + for attr in struct_attrs { + match attr.value { + syn::MetaItem::List(ref ident, ref meta_items) => { + if ident != "validate" { + continue; + } + + match meta_items[0] { + syn::NestedMetaItem::MetaItem(ref item) => match item { + &syn::MetaItem::List(ref ident2, ref args) => { + if ident2 != "schema" { + error("Only `schema` is allowed as validator on a struct") + } + + let mut function = "".to_string(); + let mut skip_on_field_errors = true; + for arg in args { + match *arg { + syn::NestedMetaItem::MetaItem(ref item) => match *item { + syn::MetaItem::NameValue(ref name, ref val) => { + match name.to_string().as_ref() { + "function" => { + function = match lit_to_string(val) { + Some(s) => s, + None => error("invalid argument type for `function` \ + : only a string is allowed"), + }; + }, + "skip_on_field_errors" => { + skip_on_field_errors = match lit_to_bool(val) { + Some(s) => s, + None => error("invalid argument type for `skip_on_field_errors` \ + : only a bool is allowed"), + }; + }, + _ => error("Unknown argument") + } + + }, + _ => error("Unexpected args") + }, + _ => error("Unexpected args") + } + } + + if function == "" { + error("`function` is required"); + } + + return Some(SchemaValidation { + function: function, + skip_on_field_errors: skip_on_field_errors + }); + }, + _ => error("Unexpected struct validator") + }, + _ => error("Unexpected struct validator") + } + }, + _ => error("Unexpected struct validator") + } + } + + None +} + + // Find all the types (as string) for each field of the struct // Needed for the `must_match` filter fn find_fields_type(fields: &Vec<syn::Field>) -> HashMap<String, String> { @@ -433,6 +538,13 @@ fn lit_to_float(lit: &syn::Lit) -> Option<f64> { } } +fn lit_to_bool(lit: &syn::Lit) -> Option<bool> { + match *lit { + syn::Lit::Bool(ref s) => Some(*s), + _ => None, + } +} + fn option_u64_to_tokens(opt: Option<u64>) -> quote::Tokens { let mut tokens = quote::Tokens::new(); tokens.append("::"); |
