@@ -108,7 +108,9 @@ def merge_default_save_config(self, default_save_config):
108
108
self .set_save_config (mode = mode , save_config_mode = SaveConfigMode ())
109
109
110
110
@classmethod
111
- def from_dict (cls , params : Dict [ModeKeys , Any ]) -> "SaveConfig" :
111
+ def from_dict (
112
+ cls , params : Dict [ModeKeys , Any ], default_values : Dict [str , Any ] = None
113
+ ) -> "SaveConfig" :
112
114
"""Parses a dict into a SaveConfig object.
113
115
114
116
Appropriate formats:
@@ -119,12 +121,17 @@ def from_dict(cls, params: Dict[ModeKeys, Any]) -> "SaveConfig":
119
121
"""
120
122
if params is None :
121
123
return None
124
+ if default_values is None :
125
+ default_values = {}
122
126
# Maybe convert strings to enums
123
- if all ([ isinstance (key , str ) for key , value in params .items ()] ):
127
+ if all (isinstance (key , str ) for key , value in params .items ()):
124
128
params = {ModeKeys [key ]: value for key , value in params .items ()}
125
129
# Maybe convert dicts to SaveConfigMode
126
- if all ([value is None or isinstance (value , dict ) for key , value in params .items ()]):
127
- params = {key : SaveConfigMode .from_dict (value ) for key , value in params .items ()}
130
+ if all (value is None or isinstance (value , dict ) for key , value in params .items ()):
131
+ params = {
132
+ key : SaveConfigMode .from_dict (value , default_values )
133
+ for key , value in params .items ()
134
+ }
128
135
return cls (mode_save_configs = params )
129
136
130
137
@classmethod
@@ -171,18 +178,18 @@ def __repr__(self):
171
178
172
179
class SaveConfigMode :
173
180
"""
174
- Wrapping all the save configuration parameters into this object.
175
- This would make it easier to set different save configuration for
176
- different collections and for the base tensors saved.
181
+ Wrapping all the save configuration parameters into this object.
182
+ This would make it easier to set different save configuration for
183
+ different collections and for the base tensors saved.
177
184
178
- This class should not be serialized by itself, only inside of SaveConfig.
185
+ This class should not be serialized by itself, only inside of SaveConfig.
179
186
180
- Parameters:
181
- save_interval (int): Save every n steps.
182
- save_steps (list of int): Save at all the steps given in this list. Overrides save_interval.
183
- start_step (int): Save after n steps.
184
- end_step (int): Stop saving after n steps.
185
- """
187
+ Parameters:
188
+ save_interval (int): Save every n steps.
189
+ save_steps (list of int): Save at all the steps given in this list. Overrides save_interval.
190
+ start_step (int): Save after n steps.
191
+ end_step (int): Stop saving after n steps.
192
+ """
186
193
187
194
def __init__ (
188
195
self ,
@@ -222,16 +229,20 @@ def to_json_dict(self):
222
229
}
223
230
224
231
@classmethod
225
- def from_dict (cls , params : Dict [str , Any ]):
232
+ def from_dict (cls , params : Dict [str , Any ], default_values : Dict [ str , Any ] = None ):
226
233
if params is None :
227
234
return None
228
235
elif not isinstance (params , dict ):
229
236
raise TypeError (f"params={ params } is not a dict." )
237
+ if default_values is None :
238
+ default_values = {}
239
+ elif not isinstance (default_values , dict ):
240
+ raise TypeError (f"default_values={ default_values } is not a dict." )
230
241
return cls (
231
- save_interval = params .get ("save_interval" ),
232
- start_step = params .get ("start_step" ),
233
- end_step = params .get ("end_step" ),
234
- save_steps = params .get ("save_steps" ),
242
+ save_interval = params .get ("save_interval" , default_values . get ( "save_interval" ) ),
243
+ start_step = params .get ("start_step" , default_values . get ( "start_step" ) ),
244
+ end_step = params .get ("end_step" , default_values . get ( "end_step" ) ),
245
+ save_steps = params .get ("save_steps" , default_values . get ( "save_steps" ) ),
235
246
)
236
247
237
248
def __eq__ (self , other ):
0 commit comments