Skip to content

Commit 27f9f1f

Browse files
committed
config_type: derive serde automatically
1 parent ef8b2eb commit 27f9f1f

File tree

2 files changed

+20
-101
lines changed

2 files changed

+20
-101
lines changed

config_proc_macro/src/attrs.rs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ pub fn config_value(attr: &syn::Attribute) -> Option<String> {
4040
get_name_value_str_lit(attr, "value")
4141
}
4242

43-
/// Returns `true` if the given attribute is a `value` attribute.
44-
pub fn is_config_value(attr: &syn::Attribute) -> bool {
45-
is_attr_name_value(attr, "value")
46-
}
47-
4843
/// Returns `true` if the given attribute is an `unstable` attribute.
4944
pub fn is_unstable_variant(attr: &syn::Attribute) -> bool {
5045
is_attr_path(attr, "unstable_variant")

config_proc_macro/src/item_enum.rs

Lines changed: 20 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -22,47 +22,42 @@ pub fn define_config_type_on_enum(args: &Args, em: &syn::ItemEnum) -> syn::Resul
2222
let mod_name_str = format!("__define_config_type_on_enum_{}", ident);
2323
let mod_name = syn::Ident::new(&mod_name_str, ident.span());
2424
let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,));
25+
let mut has_serde = false;
2526
let derives = [
2627
"std::fmt::Debug",
2728
"std::clone::Clone",
2829
"std::marker::Copy",
2930
"std::cmp::Eq",
3031
"std::cmp::PartialEq",
32+
"serde::Serialize",
33+
"serde::Deserialize",
3134
]
3235
.iter()
3336
.filter(|d| args.skip_derives().all(|s| !d.ends_with(s)))
37+
.inspect(|d| has_serde |= d.contains("serde::"))
3438
.map(|d| syn::parse_str(d).unwrap())
3539
.collect::<Vec<syn::Path>>();
36-
let derives = if derives.is_empty() {
37-
quote!()
38-
} else {
39-
quote! { #[derive( #( #derives ),* )] }
40-
};
40+
let derives = derives
41+
.is_empty()
42+
.then(|| quote!())
43+
.unwrap_or_else(|| quote! { #[derive( #( #derives ),* )] });
44+
let serde_attr = has_serde
45+
.then(|| quote!(#[serde(rename_all = "PascalCase")]))
46+
.unwrap_or_default();
4147

4248
let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants);
4349
let impl_from_str = impl_from_str(&em.ident, &em.variants);
4450
let impl_display = impl_display(&em.ident, &em.variants);
45-
let impl_serialize = if args.skip_derives().any(|s| s == "Serialize") {
46-
quote!()
47-
} else {
48-
impl_serialize(&em.ident, &em.variants)
49-
};
50-
let impl_deserialize = if args.skip_derives().any(|s| s == "Deserialize") {
51-
quote!()
52-
} else {
53-
impl_deserialize(&em.ident, &em.variants)
54-
};
5551

5652
Ok(quote! {
5753
#[allow(non_snake_case)]
5854
mod #mod_name {
5955
#derives
56+
#serde_attr
6057
pub #enum_token #ident #generics { #variants }
6158
#impl_display
6259
#impl_doc_hint
6360
#impl_from_str
64-
#impl_serialize
65-
#impl_deserialize
6661
}
6762
#vis use #mod_name::#ident;
6863
})
@@ -73,8 +68,14 @@ fn process_variant(variant: &syn::Variant) -> TokenStream {
7368
let metas = variant
7469
.attrs
7570
.iter()
76-
.filter(|attr| !is_doc_hint(attr) && !is_config_value(attr) && !is_unstable_variant(attr));
77-
let attrs = fold_quote(metas, |meta| quote!(#meta));
71+
.filter(|attr| !is_doc_hint(attr) && !is_unstable_variant(attr));
72+
let attrs = fold_quote(metas, |meta| {
73+
if let Some(rename) = config_value(meta) {
74+
quote!(#[serde(rename = #rename)])
75+
} else {
76+
quote!(#meta)
77+
}
78+
});
7879
let syn::Variant { ident, fields, .. } = variant;
7980
quote!(#attrs #ident #fields)
8081
}
@@ -188,80 +189,3 @@ fn config_value_of_variant(variant: &syn::Variant) -> String {
188189
fn unstable_of_variant(variant: &syn::Variant) -> bool {
189190
any_unstable_variant(&variant.attrs)
190191
}
191-
192-
fn impl_serialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
193-
let arms = fold_quote(variants.iter(), |v| {
194-
let v_ident = &v.ident;
195-
let pattern = match v.fields {
196-
syn::Fields::Named(..) => quote!(#ident::v_ident{..}),
197-
syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)),
198-
syn::Fields::Unit => quote!(#ident::#v_ident),
199-
};
200-
let option_value = config_value_of_variant(v);
201-
quote! {
202-
#pattern => serializer.serialize_str(&#option_value),
203-
}
204-
});
205-
206-
quote! {
207-
impl ::serde::ser::Serialize for #ident {
208-
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
209-
where
210-
S: ::serde::ser::Serializer,
211-
{
212-
use serde::ser::Error;
213-
match self {
214-
#arms
215-
_ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))),
216-
}
217-
}
218-
}
219-
}
220-
}
221-
222-
// Currently only unit variants are supported.
223-
fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
224-
let supported_vs = variants.iter().filter(|v| is_unit(v));
225-
let if_patterns = fold_quote(supported_vs, |v| {
226-
let config_value = config_value_of_variant(v);
227-
let variant_ident = &v.ident;
228-
quote! {
229-
if #config_value.eq_ignore_ascii_case(s) {
230-
return Ok(#ident::#variant_ident);
231-
}
232-
}
233-
});
234-
235-
let supported_vs = variants.iter().filter(|v| is_unit(v));
236-
let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,));
237-
238-
quote! {
239-
impl<'de> serde::de::Deserialize<'de> for #ident {
240-
fn deserialize<D>(d: D) -> Result<Self, D::Error>
241-
where
242-
D: serde::Deserializer<'de>,
243-
{
244-
use serde::de::{Error, Visitor};
245-
use std::marker::PhantomData;
246-
use std::fmt;
247-
struct StringOnly<T>(PhantomData<T>);
248-
impl<'de, T> Visitor<'de> for StringOnly<T>
249-
where T: serde::Deserializer<'de> {
250-
type Value = String;
251-
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
252-
formatter.write_str("string")
253-
}
254-
fn visit_str<E>(self, value: &str) -> Result<String, E> {
255-
Ok(String::from(value))
256-
}
257-
}
258-
let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?;
259-
260-
#if_patterns
261-
262-
static ALLOWED: &'static[&str] = &[#allowed];
263-
Err(D::Error::unknown_variant(&s, ALLOWED))
264-
}
265-
}
266-
}
267-
}

0 commit comments

Comments
 (0)