22
22
import tensorflow_datasets as tfds
23
23
from tests .tensorflow .hooks .test_mirrored_strategy import test_basic
24
24
from tests .tensorflow .keras .test_keras_mirrored import test_tf_keras
25
- from tf_utils import (
25
+
26
+ # First Party
27
+ import smdebug .tensorflow as smd
28
+ from smdebug .core .utils import SagemakerSimulator
29
+
30
+ # Local
31
+ from .tf_utils import (
26
32
get_data ,
27
33
get_estimator ,
28
34
get_input_fns ,
31
37
get_train_op_and_placeholders ,
32
38
)
33
39
34
- # First Party
35
- import smdebug .tensorflow as smd
36
- from smdebug .core .utils import SagemakerSimulator
37
40
38
-
39
- def test_estimator (script_mode : bool ):
41
+ def test_estimator (script_mode : bool = False ):
40
42
""" Works as intended. """
41
43
smd .del_hook ()
42
44
tf .reset_default_graph ()
@@ -134,7 +136,15 @@ def test_estimator_gradients_zcc(nested=False, mirrored=False):
134
136
assert len (trial .modes ()) == 2
135
137
136
138
137
- def test_linear_classifier (script_mode : bool ):
139
+ def test_estimator_gradients_zcc_nested ():
140
+ test_estimator_gradients_zcc (nested = True )
141
+
142
+
143
+ def test_estimator_gradients_zcc_mirrored ():
144
+ test_estimator_gradients_zcc (nested = False , mirrored = True )
145
+
146
+
147
+ def test_linear_classifier (script_mode : bool = False ):
138
148
""" Works as intended. """
139
149
smd .del_hook ()
140
150
tf .reset_default_graph ()
@@ -160,11 +170,20 @@ def test_linear_classifier(script_mode: bool):
160
170
assert len (trial .tensor_names ()) > 0 , "Tensors were not saved."
161
171
162
172
163
- def test_monitored_session (script_mode : bool ):
173
+ def test_monitored_session (script_mode : bool = False ):
164
174
""" Works as intended. """
165
175
smd .del_hook ()
166
176
tf .reset_default_graph ()
167
- with SagemakerSimulator () as sim :
177
+ json_file_contents = """
178
+ {
179
+ "S3OutputPath": "s3://sagemaker-test",
180
+ "LocalPath": "/opt/ml/output/tensors",
181
+ "HookParameters" : {
182
+ "save_interval": "100"
183
+ }
184
+ }
185
+ """
186
+ with SagemakerSimulator (json_file_contents = json_file_contents ) as sim :
168
187
train_op , X , Y = get_train_op_and_placeholders ()
169
188
init = tf .compat .v1 .global_variables_initializer ()
170
189
mnist = get_data ()
@@ -195,6 +214,9 @@ def test_monitored_session_gradients_zcc():
195
214
{
196
215
"S3OutputPath": "s3://sagemaker-test",
197
216
"LocalPath": "/opt/ml/output/tensors",
217
+ "HookParameters" : {
218
+ "save_interval": "100"
219
+ },
198
220
"CollectionConfigurations": [
199
221
{
200
222
"CollectionName": "gradients"
@@ -227,7 +249,7 @@ def test_monitored_session_gradients_zcc():
227
249
assert len (trial .tensor_names (collection = "gradients" )) > 0
228
250
229
251
230
- def test_keras_v1 (script_mode : bool ):
252
+ def test_keras_v1 (script_mode : bool = False ):
231
253
""" Works as intended. """
232
254
smd .del_hook ()
233
255
tf .reset_default_graph ()
@@ -258,7 +280,7 @@ def test_keras_v1(script_mode: bool):
258
280
assert len (trial .tensor_names ()) > 0 , "Tensors were not saved."
259
281
260
282
261
- def test_keras_gradients (script_mode : bool , tf_optimizer : bool = False ):
283
+ def test_keras_gradients (script_mode : bool = False , tf_optimizer : bool = False ):
262
284
""" Works as intended. """
263
285
smd .del_hook ()
264
286
tf .reset_default_graph ()
@@ -320,6 +342,10 @@ def test_keras_gradients(script_mode: bool, tf_optimizer: bool = False):
320
342
assert len (trial .tensor_names (collection = "optimizer_variables" )) > 0
321
343
322
344
345
+ def test_keras_gradients_tf_opt (script_mode : bool = False ):
346
+ test_keras_gradients (script_mode = script_mode , tf_optimizer = True )
347
+
348
+
323
349
def test_keras_gradients_mirrored (include_workers = "one" ):
324
350
""" Works as intended. """
325
351
smd .del_hook ()
@@ -366,7 +392,11 @@ def test_keras_gradients_mirrored(include_workers="one"):
366
392
test_tf_keras ("/opt/ml/output/tensors" , zcc = True , include_workers = include_workers )
367
393
368
394
369
- def test_keras_to_estimator (script_mode : bool ):
395
+ def test_keras_gradients_mirrored_all_workers ():
396
+ test_keras_gradients_mirrored (include_workers = "all" )
397
+
398
+
399
+ def test_keras_to_estimator (script_mode : bool = False ):
370
400
""" Works as intended. """
371
401
import tensorflow .compat .v1 .keras as keras
372
402
@@ -426,14 +456,14 @@ def input_fn():
426
456
test_monitored_session_gradients_zcc ()
427
457
test_estimator (script_mode = script_mode )
428
458
if not script_mode :
429
- test_estimator_gradients_zcc (nested = True )
430
- test_estimator_gradients_zcc ( nested = False )
431
- test_estimator_gradients_zcc ( nested = False , mirrored = True )
459
+ test_estimator_gradients_zcc ()
460
+ test_estimator_gradients_zcc_nested ( )
461
+ test_estimator_gradients_zcc_mirrored ( )
432
462
test_linear_classifier (script_mode = script_mode )
433
463
test_keras_v1 (script_mode = script_mode )
434
464
test_keras_gradients (script_mode = script_mode )
435
- test_keras_gradients (script_mode = script_mode , tf_optimizer = True )
465
+ test_keras_gradients_tf_opt (script_mode = script_mode )
436
466
test_keras_to_estimator (script_mode = script_mode )
437
467
if not script_mode :
438
- test_keras_gradients_mirrored ( include_workers = "all" )
468
+ test_keras_gradients_mirrored_all_workers ( )
439
469
test_keras_gradients_mirrored ()
0 commit comments