@@ -22,47 +22,42 @@ pub fn define_config_type_on_enum(args: &Args, em: &syn::ItemEnum) -> syn::Resul
22
22
let mod_name_str = format ! ( "__define_config_type_on_enum_{}" , ident) ;
23
23
let mod_name = syn:: Ident :: new ( & mod_name_str, ident. span ( ) ) ;
24
24
let variants = fold_quote ( variants. iter ( ) . map ( process_variant) , |meta| quote ! ( #meta, ) ) ;
25
+ let mut has_serde = false ;
25
26
let derives = [
26
27
"std::fmt::Debug" ,
27
28
"std::clone::Clone" ,
28
29
"std::marker::Copy" ,
29
30
"std::cmp::Eq" ,
30
31
"std::cmp::PartialEq" ,
32
+ "serde::Serialize" ,
33
+ "serde::Deserialize" ,
31
34
]
32
35
. iter ( )
33
36
. filter ( |d| args. skip_derives ( ) . all ( |s| !d. ends_with ( s) ) )
37
+ . inspect ( |d| has_serde |= d. contains ( "serde::" ) )
34
38
. map ( |d| syn:: parse_str ( d) . unwrap ( ) )
35
39
. collect :: < Vec < syn:: Path > > ( ) ;
36
- let derives = if derives. is_empty ( ) {
37
- quote ! ( )
38
- } else {
39
- quote ! { #[ derive( #( #derives ) , * ) ] }
40
- } ;
40
+ let derives = derives
41
+ . is_empty ( )
42
+ . then ( || quote ! ( ) )
43
+ . unwrap_or_else ( || quote ! { #[ derive( #( #derives ) , * ) ] } ) ;
44
+ let serde_attr = has_serde
45
+ . then ( || quote ! ( #[ serde( rename_all = "PascalCase" ) ] ) )
46
+ . unwrap_or_default ( ) ;
41
47
42
48
let impl_doc_hint = impl_doc_hint ( & em. ident , & em. variants ) ;
43
49
let impl_from_str = impl_from_str ( & em. ident , & em. variants ) ;
44
50
let impl_display = impl_display ( & em. ident , & em. variants ) ;
45
- let impl_serialize = if args. skip_derives ( ) . any ( |s| s == "Serialize" ) {
46
- quote ! ( )
47
- } else {
48
- impl_serialize ( & em. ident , & em. variants )
49
- } ;
50
- let impl_deserialize = if args. skip_derives ( ) . any ( |s| s == "Deserialize" ) {
51
- quote ! ( )
52
- } else {
53
- impl_deserialize ( & em. ident , & em. variants )
54
- } ;
55
51
56
52
Ok ( quote ! {
57
53
#[ allow( non_snake_case) ]
58
54
mod #mod_name {
59
55
#derives
56
+ #serde_attr
60
57
pub #enum_token #ident #generics { #variants }
61
58
#impl_display
62
59
#impl_doc_hint
63
60
#impl_from_str
64
- #impl_serialize
65
- #impl_deserialize
66
61
}
67
62
#vis use #mod_name:: #ident;
68
63
} )
@@ -73,8 +68,14 @@ fn process_variant(variant: &syn::Variant) -> TokenStream {
73
68
let metas = variant
74
69
. attrs
75
70
. iter ( )
76
- . filter ( |attr| !is_doc_hint ( attr) && !is_config_value ( attr) && !is_unstable_variant ( attr) ) ;
77
- let attrs = fold_quote ( metas, |meta| quote ! ( #meta) ) ;
71
+ . filter ( |attr| !is_doc_hint ( attr) && !is_unstable_variant ( attr) ) ;
72
+ let attrs = fold_quote ( metas, |meta| {
73
+ if let Some ( rename) = config_value ( meta) {
74
+ quote ! ( #[ serde( rename = #rename) ] )
75
+ } else {
76
+ quote ! ( #meta)
77
+ }
78
+ } ) ;
78
79
let syn:: Variant { ident, fields, .. } = variant;
79
80
quote ! ( #attrs #ident #fields)
80
81
}
@@ -188,80 +189,3 @@ fn config_value_of_variant(variant: &syn::Variant) -> String {
188
189
fn unstable_of_variant ( variant : & syn:: Variant ) -> bool {
189
190
any_unstable_variant ( & variant. attrs )
190
191
}
191
-
192
- fn impl_serialize ( ident : & syn:: Ident , variants : & Variants ) -> TokenStream {
193
- let arms = fold_quote ( variants. iter ( ) , |v| {
194
- let v_ident = & v. ident ;
195
- let pattern = match v. fields {
196
- syn:: Fields :: Named ( ..) => quote ! ( #ident:: v_ident{ ..} ) ,
197
- syn:: Fields :: Unnamed ( ..) => quote ! ( #ident:: #v_ident( ..) ) ,
198
- syn:: Fields :: Unit => quote ! ( #ident:: #v_ident) ,
199
- } ;
200
- let option_value = config_value_of_variant ( v) ;
201
- quote ! {
202
- #pattern => serializer. serialize_str( & #option_value) ,
203
- }
204
- } ) ;
205
-
206
- quote ! {
207
- impl :: serde:: ser:: Serialize for #ident {
208
- fn serialize<S >( & self , serializer: S ) -> Result <S :: Ok , S :: Error >
209
- where
210
- S : :: serde:: ser:: Serializer ,
211
- {
212
- use serde:: ser:: Error ;
213
- match self {
214
- #arms
215
- _ => Err ( S :: Error :: custom( format!( "Cannot serialize {:?}" , self ) ) ) ,
216
- }
217
- }
218
- }
219
- }
220
- }
221
-
222
- // Currently only unit variants are supported.
223
- fn impl_deserialize ( ident : & syn:: Ident , variants : & Variants ) -> TokenStream {
224
- let supported_vs = variants. iter ( ) . filter ( |v| is_unit ( v) ) ;
225
- let if_patterns = fold_quote ( supported_vs, |v| {
226
- let config_value = config_value_of_variant ( v) ;
227
- let variant_ident = & v. ident ;
228
- quote ! {
229
- if #config_value. eq_ignore_ascii_case( s) {
230
- return Ok ( #ident:: #variant_ident) ;
231
- }
232
- }
233
- } ) ;
234
-
235
- let supported_vs = variants. iter ( ) . filter ( |v| is_unit ( v) ) ;
236
- let allowed = fold_quote ( supported_vs. map ( config_value_of_variant) , |s| quote ! ( #s, ) ) ;
237
-
238
- quote ! {
239
- impl <' de> serde:: de:: Deserialize <' de> for #ident {
240
- fn deserialize<D >( d: D ) -> Result <Self , D :: Error >
241
- where
242
- D : serde:: Deserializer <' de>,
243
- {
244
- use serde:: de:: { Error , Visitor } ;
245
- use std:: marker:: PhantomData ;
246
- use std:: fmt;
247
- struct StringOnly <T >( PhantomData <T >) ;
248
- impl <' de, T > Visitor <' de> for StringOnly <T >
249
- where T : serde:: Deserializer <' de> {
250
- type Value = String ;
251
- fn expecting( & self , formatter: & mut fmt:: Formatter <' _>) -> fmt:: Result {
252
- formatter. write_str( "string" )
253
- }
254
- fn visit_str<E >( self , value: & str ) -> Result <String , E > {
255
- Ok ( String :: from( value) )
256
- }
257
- }
258
- let s = & d. deserialize_string( StringOnly :: <D >( PhantomData ) ) ?;
259
-
260
- #if_patterns
261
-
262
- static ALLOWED : & ' static [ & str ] = & [ #allowed] ;
263
- Err ( D :: Error :: unknown_variant( & s, ALLOWED ) )
264
- }
265
- }
266
- }
267
- }
0 commit comments