diff options
Diffstat (limited to 'validator_derive')
| -rw-r--r-- | validator_derive/src/lib.rs | 118 | ||||
| -rw-r--r-- | validator_derive/tests/compile-fail/schema/missing_function.rs | 20 | ||||
| -rw-r--r-- | validator_derive/tests/run-pass/custom.rs | 2 | ||||
| -rw-r--r-- | validator_derive/tests/run-pass/email.rs | 2 | ||||
| -rw-r--r-- | validator_derive/tests/run-pass/length.rs | 2 | ||||
| -rw-r--r-- | validator_derive/tests/run-pass/must_match.rs | 2 | ||||
| -rw-r--r-- | validator_derive/tests/run-pass/range.rs | 2 | ||||
| -rw-r--r-- | validator_derive/tests/run-pass/schema.rs | 27 | ||||
| -rw-r--r-- | validator_derive/tests/test_derive.rs | 81 |
9 files changed, 248 insertions, 8 deletions
diff --git a/validator_derive/src/lib.rs b/validator_derive/src/lib.rs index eaf82e1..92f3ea5 100644 --- a/validator_derive/src/lib.rs +++ b/validator_derive/src/lib.rs @@ -14,9 +14,17 @@ use validator::{Validator}; static RANGE_TYPES: [&'static str; 12] = [ - "usize", "u8", "u16", "u32", "u64", "isize", "i8", "i16", "i32", "i64", "f32", "f64" + "usize", "u8", "u16", "u32", "u64", + "isize", "i8", "i16", "i32", "i64", + "f32", "f64", ]; +#[derive(Debug)] +struct SchemaValidation { + function: String, + skip_on_field_errors: bool, +} + #[proc_macro_derive(Validate, attributes(validate))] pub fn derive_validation(input: TokenStream) -> TokenStream { @@ -107,7 +115,9 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { let fn_ident = syn::Ident::new(f.clone()); quote!( match #fn_ident(&self.#field_ident) { - ::std::option::Option::Some(s) => errors.entry(#name.to_string()).or_insert_with(|| vec![]).push(s), + ::std::option::Option::Some(s) => { + errors.entry(#name.to_string()).or_insert_with(|| vec![]).push(s) + }, ::std::option::Option::None => (), }; ) @@ -116,6 +126,25 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { } } + let struct_validation = find_struct_validation(&ast.attrs); + let struct_validation_tokens = match struct_validation { + Some(s) => { + let fn_ident = syn::Ident::new(s.function); + let skip_on_field_errors = s.skip_on_field_errors; + quote!( + if !#skip_on_field_errors || #skip_on_field_errors && errors.len() == 0 { + match #fn_ident(self) { + ::std::option::Option::Some((key, val)) => { + errors.entry(key).or_insert_with(|| vec![]).push(val) + }, + ::std::option::Option::None => (), + } + } + ) + }, + None => quote!() + }; + let ident = &ast.ident; let impl_ast = quote!( impl Validate for #ident { @@ -123,7 +152,9 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { use std::collections::HashMap; let mut errors = HashMap::new(); - #(#validations)* + #(#validations)* + + #struct_validation_tokens if errors.is_empty() { ::std::result::Result::Ok(()) @@ -137,6 +168,80 @@ fn expand_validation(ast: &syn::MacroInput) -> quote::Tokens { impl_ast } + +/// Find if a struct has some schema validation and returns the info if so +fn find_struct_validation(struct_attrs: &Vec<syn::Attribute>) -> Option<SchemaValidation> { + let error = |msg: &str| -> ! { + panic!("Invalid schema level validation: {}", msg); + }; + + for attr in struct_attrs { + match attr.value { + syn::MetaItem::List(ref ident, ref meta_items) => { + if ident != "validate" { + continue; + } + + match meta_items[0] { + syn::NestedMetaItem::MetaItem(ref item) => match item { + &syn::MetaItem::List(ref ident2, ref args) => { + if ident2 != "schema" { + error("Only `schema` is allowed as validator on a struct") + } + + let mut function = "".to_string(); + let mut skip_on_field_errors = true; + for arg in args { + match *arg { + syn::NestedMetaItem::MetaItem(ref item) => match *item { + syn::MetaItem::NameValue(ref name, ref val) => { + match name.to_string().as_ref() { + "function" => { + function = match lit_to_string(val) { + Some(s) => s, + None => error("invalid argument type for `function` \ + : only a string is allowed"), + }; + }, + "skip_on_field_errors" => { + skip_on_field_errors = match lit_to_bool(val) { + Some(s) => s, + None => error("invalid argument type for `skip_on_field_errors` \ + : only a bool is allowed"), + }; + }, + _ => error("Unknown argument") + } + + }, + _ => error("Unexpected args") + }, + _ => error("Unexpected args") + } + } + + if function == "" { + error("`function` is required"); + } + + return Some(SchemaValidation { + function: function, + skip_on_field_errors: skip_on_field_errors + }); + }, + _ => error("Unexpected struct validator") + }, + _ => error("Unexpected struct validator") + } + }, + _ => error("Unexpected struct validator") + } + } + + None +} + + // Find all the types (as string) for each field of the struct // Needed for the `must_match` filter fn find_fields_type(fields: &Vec<syn::Field>) -> HashMap<String, String> { @@ -433,6 +538,13 @@ fn lit_to_float(lit: &syn::Lit) -> Option<f64> { } } +fn lit_to_bool(lit: &syn::Lit) -> Option<bool> { + match *lit { + syn::Lit::Bool(ref s) => Some(*s), + _ => None, + } +} + fn option_u64_to_tokens(opt: Option<u64>) -> quote::Tokens { let mut tokens = quote::Tokens::new(); tokens.append("::"); diff --git a/validator_derive/tests/compile-fail/schema/missing_function.rs b/validator_derive/tests/compile-fail/schema/missing_function.rs new file mode 100644 index 0000000..cacb328 --- /dev/null +++ b/validator_derive/tests/compile-fail/schema/missing_function.rs @@ -0,0 +1,20 @@ +#![feature(proc_macro, attr_literals)] + +#[macro_use] extern crate validator_derive; +extern crate validator; +use validator::Validate; + +#[derive(Validate)] +//~^ ERROR: custom derive attribute panicked +//~^^ HELP: Invalid schema level validation: `function` is required +#[validate(schema())] +struct Test { + s: i32, +} + +fn hey(_: &Test) -> Option<(String, String)> { + None +} + + +fn main() {} diff --git a/validator_derive/tests/run-pass/custom.rs b/validator_derive/tests/run-pass/custom.rs index 205198e..642b2d2 100644 --- a/validator_derive/tests/run-pass/custom.rs +++ b/validator_derive/tests/run-pass/custom.rs @@ -1,4 +1,4 @@ -#![feature(proc_macro, attr_literals)] +#![feature(attr_literals)] #[macro_use] extern crate validator_derive; extern crate validator; diff --git a/validator_derive/tests/run-pass/email.rs b/validator_derive/tests/run-pass/email.rs index edfc357..014c7b8 100644 --- a/validator_derive/tests/run-pass/email.rs +++ b/validator_derive/tests/run-pass/email.rs @@ -1,4 +1,4 @@ -#![feature(proc_macro, attr_literals)] +#![feature(attr_literals)] #[macro_use] extern crate validator_derive; extern crate validator; diff --git a/validator_derive/tests/run-pass/length.rs b/validator_derive/tests/run-pass/length.rs index 01b85ea..1e0d30e 100644 --- a/validator_derive/tests/run-pass/length.rs +++ b/validator_derive/tests/run-pass/length.rs @@ -1,4 +1,4 @@ -#![feature(proc_macro, attr_literals)] +#![feature(attr_literals)] #[macro_use] extern crate validator_derive; extern crate validator; diff --git a/validator_derive/tests/run-pass/must_match.rs b/validator_derive/tests/run-pass/must_match.rs index 0d2d917..c79d20d 100644 --- a/validator_derive/tests/run-pass/must_match.rs +++ b/validator_derive/tests/run-pass/must_match.rs @@ -1,4 +1,4 @@ -#![feature(proc_macro, attr_literals)] +#![feature(attr_literals)] #[macro_use] extern crate validator_derive; extern crate validator; diff --git a/validator_derive/tests/run-pass/range.rs b/validator_derive/tests/run-pass/range.rs index 8f3a047..79e3229 100644 --- a/validator_derive/tests/run-pass/range.rs +++ b/validator_derive/tests/run-pass/range.rs @@ -1,4 +1,4 @@ -#![feature(proc_macro, attr_literals)] +#![feature(attr_literals)] #[macro_use] extern crate validator_derive; extern crate validator; diff --git a/validator_derive/tests/run-pass/schema.rs b/validator_derive/tests/run-pass/schema.rs new file mode 100644 index 0000000..788d1e2 --- /dev/null +++ b/validator_derive/tests/run-pass/schema.rs @@ -0,0 +1,27 @@ +#![feature(attr_literals)] + +#[macro_use] extern crate validator_derive; +extern crate validator; +use validator::Validate; + +#[derive(Validate)] +#[validate(schema(function = "hey"))] +struct Test { + s: String, +} + +fn hey(_: &Test) -> Option<(String, String)> { + None +} + +#[derive(Validate)] +#[validate(schema(function = "hey2", skip_on_field_errors = false))] +struct Test2 { + s: String, +} + +fn hey2(_: &Test2) -> Option<(String, String)> { + None +} + +fn main() {} diff --git a/validator_derive/tests/test_derive.rs b/validator_derive/tests/test_derive.rs index 9a11b79..6832762 100644 --- a/validator_derive/tests/test_derive.rs +++ b/validator_derive/tests/test_derive.rs @@ -9,6 +9,7 @@ use validator::Validate; #[derive(Debug, Validate, Deserialize)] +#[validate(schema(function = "validate_signup", skip_on_field_errors = false))] struct SignupData { #[validate(email)] mail: String, @@ -37,6 +38,44 @@ fn validate_unique_username(username: &str) -> Option<String> { None } +fn validate_signup(data: &SignupData) -> Option<(String, String)> { + if data.mail.ends_with("gmail.com") && data.age == 18 { + return Some(("all".to_string(), "stupid_rule".to_string())); + } + + None +} + +#[derive(Debug, Validate, Deserialize)] +#[validate(schema(function = "validate_signup2", skip_on_field_errors = false))] +struct SignupData2 { + #[validate(email)] + mail: String, + #[validate(range(min = 18, max = 20))] + age: u32, +} + +#[derive(Debug, Validate, Deserialize)] +#[validate(schema(function = "validate_signup3"))] +struct SignupData3 { + #[validate(email)] + mail: String, + #[validate(range(min = 18, max = 20))] + age: u32, +} + +fn validate_signup2(data: &SignupData2) -> Option<(String, String)> { + if data.mail.starts_with("bob") && data.age == 18 { + return Some(("mail".to_string(), "stupid_rule".to_string())); + } + + None +} + +fn validate_signup3(_: &SignupData3) -> Option<(String, String)> { + Some(("mail".to_string(), "stupid_rule".to_string())) +} + #[test] fn test_can_validate_ok() { let signup = SignupData { @@ -161,3 +200,45 @@ fn test_must_match_can_fail() { }; assert!(data.validate().is_err()) } + +#[test] +fn test_can_fail_struct_validation_new_key() { + let signup = SignupData { + mail: "bob@gmail.com".to_string(), + site: "https://hello.com".to_string(), + first_name: "xXxShad0wxXx".to_string(), + age: 18, + }; + let res = signup.validate(); + assert!(res.is_err()); + let errs = res.unwrap_err(); + assert!(errs.contains_key("all")); + assert_eq!(errs["all"], vec!["stupid_rule".to_string()]); +} + +#[test] +fn test_can_fail_struct_validation_existing_key() { + let signup = SignupData2 { + mail: "bob".to_string(), + age: 18, + }; + let res = signup.validate(); + assert!(res.is_err()); + let errs = res.unwrap_err(); + assert!(errs.contains_key("mail")); + assert_eq!(errs["mail"], vec!["email".to_string(), "stupid_rule".to_string()]); +} + +#[test] +fn test_skip_struct_validation_by_default_if_errors() { + let signup = SignupData3 { + mail: "bob".to_string(), + age: 18, + }; + let res = signup.validate(); + assert!(res.is_err()); + let errs = res.unwrap_err(); + assert!(errs.contains_key("mail")); + assert_eq!(errs["mail"], vec!["email".to_string()]); + +} |
