14
14
import tensorflow .compat .v2 as tf
15
15
import tensorflow_datasets as tfds
16
16
from tests .constants import TEST_DATASET_S3_PATH
17
- from tests .tensorflow2 .utils import is_tf_2_2
17
+ from tests .tensorflow2 .utils import is_tf_2_2 , is_tf_2_3
18
18
from tests .tensorflow .utils import create_trial_fast_refresh
19
19
from tests .utils import use_s3_datasets
20
20
@@ -195,7 +195,7 @@ def test_keras_gradtape(out_dir, saveall):
195
195
196
196
trial = smd .create_trial (path = out_dir )
197
197
if saveall : # save losses, metrics, weights, biases
198
- assert len (trial .tensor_names ()) == 15
198
+ assert len (trial .tensor_names ()) == ( 25 if is_tf_2_2 () else 15 )
199
199
assert len (trial .tensor_names (collection = CollectionKeys .BIASES )) == 2
200
200
assert len (trial .tensor_names (collection = CollectionKeys .WEIGHTS )) == 2
201
201
assert len (trial .tensor_names (collection = CollectionKeys .OPTIMIZER_VARIABLES )) == 5
@@ -275,7 +275,7 @@ def test_gradtape_include_regex(out_dir):
275
275
tr = create_trial_fast_refresh (out_dir )
276
276
tnames = tr .tensor_names (collection = "custom_coll" )
277
277
278
- assert len (tnames ) == 8
278
+ assert len (tnames ) == ( 12 if is_tf_2_2 () else 8 )
279
279
for tname in tnames :
280
280
assert tr .tensor (tname ).value (0 ) is not None
281
281
@@ -343,7 +343,7 @@ def test_gradtape_include_collections(out_dir):
343
343
344
344
trial = smd .create_trial (path = out_dir )
345
345
# can't save gradients in TF 2.x
346
- assert len (trial .tensor_names ()) == 15
346
+ assert len (trial .tensor_names ()) == ( 16 if is_tf_2_2 () else 15 )
347
347
assert len (trial .tensor_names (collection = CollectionKeys .GRADIENTS )) == 4
348
348
assert len (trial .tensor_names (collection = CollectionKeys .OPTIMIZER_VARIABLES )) == 5
349
349
assert len (trial .tensor_names (collection = CollectionKeys .BIASES )) == 2
@@ -388,7 +388,7 @@ def test_gradtape_persistent(out_dir, saveall):
388
388
389
389
trial = smd .create_trial (path = out_dir )
390
390
if saveall : # save losses, metrics, weights, biases
391
- assert len (trial .tensor_names ()) == 15
391
+ assert len (trial .tensor_names ()) == ( 25 if is_tf_2_2 () else 15 )
392
392
assert len (trial .tensor_names (collection = CollectionKeys .BIASES )) == 2
393
393
assert len (trial .tensor_names (collection = CollectionKeys .WEIGHTS )) == 2
394
394
assert len (trial .tensor_names (collection = CollectionKeys .OPTIMIZER_VARIABLES )) == 5
@@ -409,17 +409,24 @@ def test_keras_fit(out_dir, tf_eager_mode, saveall):
409
409
helper_keras_fit (
410
410
trial_dir = out_dir ,
411
411
hook = hook ,
412
- eager = tf_eager_mode ,
412
+ run_eagerly = tf_eager_mode ,
413
413
steps = ["train" , "eval" , "predict" , "train" ],
414
414
)
415
415
416
416
trial = smd .create_trial (path = out_dir )
417
417
# can't save gradients in TF 2.x eager mode
418
418
if saveall : # save losses, metrics, weights, biases, scalar
419
419
if tf_eager_mode :
420
- assert len (trial .tensor_names ()) == (13 if is_tf_2_2 () else 14 )
421
- assert len (trial .tensor_names (collection = CollectionKeys .INPUTS )) == 0
422
- assert len (trial .tensor_names (collection = CollectionKeys .OUTPUTS )) == 0
420
+ if is_tf_2_2 ():
421
+ assert len (trial .tensor_names ()) == 28
422
+ else :
423
+ assert len (trial .tensor_names ()) == (21 if is_tf_2_3 () else 14 )
424
+ assert len (trial .tensor_names (collection = CollectionKeys .INPUTS )) == (
425
+ 1 if is_tf_2_2 () else 0
426
+ )
427
+ assert len (trial .tensor_names (collection = CollectionKeys .OUTPUTS )) == (
428
+ 2 if is_tf_2_2 () else 0
429
+ )
423
430
else :
424
431
assert len (trial .tensor_names ()) == 21
425
432
assert len (trial .tensor_names (collection = CollectionKeys .BIASES )) == 2
@@ -435,10 +442,12 @@ def test_keras_fit(out_dir, tf_eager_mode, saveall):
435
442
"No Optimizer Variables Should be Saved in EVAL Mode" ,
436
443
)
437
444
else : # save the default losses and metrics
438
- assert len (trial .tensor_names ()) == (4 if is_tf_2_2 () and tf_eager_mode else 5 )
445
+ assert len (trial .tensor_names ()) == (
446
+ 4 if (is_tf_2_2 () or is_tf_2_3 ()) and tf_eager_mode else 5
447
+ )
439
448
assert len (trial .tensor_names (collection = CollectionKeys .LOSSES )) == 1
440
449
assert len (trial .tensor_names (collection = CollectionKeys .METRICS )) == (
441
- 2 if is_tf_2_2 () and tf_eager_mode else 3
450
+ 2 if ( is_tf_2_2 () or is_tf_2_3 () ) and tf_eager_mode else 3
442
451
)
443
452
for tname in trial .tensor_names ():
444
453
assert trial .tensor (tname ).value (0 ) is not None
@@ -510,7 +519,7 @@ def test_include_regex(out_dir, tf_eager_mode):
510
519
tnames = tr .tensor_names (collection = "custom_coll" )
511
520
512
521
if tf_eager_mode :
513
- assert len (tnames ) == 8
522
+ assert len (tnames ) == ( 12 if is_tf_2_2 () else 8 )
514
523
else :
515
524
assert len (tnames ) == 8
516
525
for tname in tnames :
@@ -534,7 +543,7 @@ def test_clash_with_tb_callback(out_dir):
534
543
add_callbacks = ["tensorboard" ],
535
544
)
536
545
tr = create_trial_fast_refresh (out_dir )
537
- assert len (tr .tensor_names ()) == (7 if is_tf_2_2 () else 8 )
546
+ assert len (tr .tensor_names ()) == (7 if ( is_tf_2_2 () or is_tf_2_3 () ) else 8 )
538
547
539
548
540
549
@pytest .mark .slow
@@ -560,12 +569,12 @@ def test_weights_collections(out_dir, tf_eager_mode):
560
569
561
570
trial = smd .create_trial (path = out_dir )
562
571
# can't save gradients in TF 2.x
563
- assert len (trial .tensor_names ()) == (5 if is_tf_2_2 () and tf_eager_mode else 6 )
572
+ assert len (trial .tensor_names ()) == (5 if ( is_tf_2_2 () or is_tf_2_3 () ) and tf_eager_mode else 6 )
564
573
assert len (trial .tensor_names (collection = CollectionKeys .BIASES )) == 0
565
574
assert len (trial .tensor_names (collection = CollectionKeys .WEIGHTS )) == 2
566
575
assert len (trial .tensor_names (collection = CollectionKeys .LOSSES )) == 1
567
576
assert len (trial .tensor_names (collection = CollectionKeys .METRICS )) == (
568
- 2 if is_tf_2_2 () and tf_eager_mode else 3
577
+ 2 if ( is_tf_2_2 () or is_tf_2_3 () ) and tf_eager_mode else 3
569
578
)
570
579
571
580
@@ -595,7 +604,10 @@ def test_include_collections(out_dir, tf_eager_mode):
595
604
trial = smd .create_trial (path = out_dir )
596
605
# can't save gradients in TF 2.x
597
606
if tf_eager_mode :
598
- assert len (trial .tensor_names ()) == (12 if is_tf_2_2 () else 13 )
607
+ if is_tf_2_2 ():
608
+ assert len (trial .tensor_names ()) == 16
609
+ else :
610
+ assert len (trial .tensor_names ()) == (12 if is_tf_2_3 () else 13 )
599
611
else :
600
612
assert len (trial .tensor_names ()) == 18
601
613
assert len (trial .tensor_names (collection = CollectionKeys .GRADIENTS )) == 4
@@ -605,7 +617,7 @@ def test_include_collections(out_dir, tf_eager_mode):
605
617
assert len (trial .tensor_names (collection = CollectionKeys .WEIGHTS )) == 2
606
618
assert len (trial .tensor_names (collection = CollectionKeys .LOSSES )) == 1
607
619
assert len (trial .tensor_names (collection = CollectionKeys .METRICS )) == (
608
- 2 if is_tf_2_2 () and tf_eager_mode else 3
620
+ 2 if ( is_tf_2_2 () or is_tf_2_3 () ) and tf_eager_mode else 3
609
621
)
610
622
611
623
@@ -625,7 +637,7 @@ def test_include_only_custom_collection(out_dir, tf_eager_mode):
625
637
)
626
638
627
639
trial = smd .create_trial (path = out_dir )
628
- assert len (trial .tensor_names ()) == (8 if is_tf_2_2 () and tf_eager_mode else 9 )
640
+ assert len (trial .tensor_names ()) == (8 if ( is_tf_2_2 () or is_tf_2_3 () ) and tf_eager_mode else 9 )
629
641
assert len (trial .tensor_names (collection = "custom_optimizer_variables" )) == 5
630
642
631
643
@@ -640,12 +652,12 @@ def test_hook_from_json(out_dir, tf_eager_mode, monkeypatch):
640
652
641
653
trial = smd .create_trial (path = out_dir )
642
654
# can't save gradients in TF 2.x
643
- assert len (trial .tensor_names ()) == (5 if is_tf_2_2 () and tf_eager_mode else 6 )
655
+ assert len (trial .tensor_names ()) == (5 if ( is_tf_2_2 () or is_tf_2_3 () ) and tf_eager_mode else 6 )
644
656
assert len (trial .tensor_names (collection = CollectionKeys .BIASES )) == 0
645
657
assert len (trial .tensor_names (collection = CollectionKeys .WEIGHTS )) == 2
646
658
assert len (trial .tensor_names (collection = CollectionKeys .LOSSES )) == 1
647
659
assert len (trial .tensor_names (collection = CollectionKeys .METRICS )) == (
648
- 2 if is_tf_2_2 () and tf_eager_mode else 3
660
+ 2 if ( is_tf_2_2 () or is_tf_2_3 () ) and tf_eager_mode else 3
649
661
)
650
662
651
663
@@ -658,12 +670,15 @@ def test_keras_fit_pure_eager(out_dir, tf_eager_mode):
658
670
helper_keras_fit (trial_dir = out_dir , hook = hook , eager = tf_eager_mode , run_eagerly = True )
659
671
660
672
trial = smd .create_trial (path = out_dir )
661
- assert len (trial .tensor_names ()) == (20 if is_tf_2_2 () else 21 )
673
+ if is_tf_2_2 ():
674
+ assert len (trial .tensor_names ()) == 27
675
+ else :
676
+ assert len (trial .tensor_names ()) == (20 if is_tf_2_3 () else 21 )
662
677
assert len (trial .tensor_names (collection = CollectionKeys .BIASES )) == 2
663
678
assert len (trial .tensor_names (collection = CollectionKeys .WEIGHTS )) == 2
664
679
assert len (trial .tensor_names (collection = CollectionKeys .OPTIMIZER_VARIABLES )) == 5
665
- assert len (trial .tensor_names (collection = CollectionKeys .INPUTS )) == 0
666
- assert len (trial .tensor_names (collection = CollectionKeys .OUTPUTS )) == 0
680
+ assert len (trial .tensor_names (collection = CollectionKeys .INPUTS )) == ( 1 if is_tf_2_2 () else 0 )
681
+ assert len (trial .tensor_names (collection = CollectionKeys .OUTPUTS )) == ( 2 if is_tf_2_2 () else 0 )
667
682
668
683
669
684
@pytest .mark .skip # skip until aws tf update
0 commit comments