
Commit b4fa5d9

TF Keras support for TF 1.x (aws#304)
* Introduce TF Tensor class
* rename class
* Brought in most of keras2 changes to this branch
* Add keras code
* Add keras method to collection'
* Fix naming of reduction tensors
* merge tfcollection2
* WIP
* Introduce create flag in get_hook
* addressed some comments, and fixed get_hook
* Address comments
* Address comments
* Refactored wrap_optimizer logic
* Changed default value of create_if_not_exists to False, and unset gradients on each sess begin
* Update tf in CI to 1.15
* Add check for ALL when checking if collection should be saved
* Fixing more issues with keras support
* Fixed the predict bug, as well save_all bug. Also added multiple tests
* Add reduction support
* Change log message
* Add training end test
* Cleanup and fix save of metrics
* Fix losses and metrics collection should be saved
* Add tests for regex patterns
* fix import
* Check if tensor callback already exists
* Add keras tests
* fix a reduction test, and remove duplicated test
* Fix docs and examples for distributed after wrap_optimizer change, and donot initialize writer in save_for_tensor
* Get keras working again
* Fix tests, and variable name issues
* add another test
* Add all working mirrored strategy
* All keras tests working on ec2
* Make mirrored strategy work on laptop (no GPU) as well
* Add tensorboard args to new hooks
* get session hook working for mirrored strategy
* add tests for mirrored strategy
* update test
* add test
* Fix check of isinstance
* make apt less verbose
* install tfdatasets
* Change running of tests for debugging
* Tensor device map: removed the check of strategy
* Mark more tests as slow
* trigger CI
* add pip freeze
* add pip freeze
* Upgrade pip so hopefully new keras is installed
* pin keras version
* pin keras version
* run all unit tests
* run only keras with debug Signed-off-by: Rahul Huilgol <[email protected]>
* all unit tests
* run only keras with debug Signed-off-by: Rahul Huilgol <[email protected]>
* run only keras with debug Signed-off-by: Rahul Huilgol <[email protected]>
* run only keras with debug Signed-off-by: Rahul Huilgol <[email protected]>
* empty commit Signed-off-by: Rahul Huilgol <[email protected]>
* print step 0 Signed-off-by: Rahul Huilgol <[email protected]>
* add all tests Signed-off-by: Rahul Huilgol <[email protected]>
* add all tests Signed-off-by: Rahul Huilgol <[email protected]>
* rename ts_tensor
* remove kwargs Signed-off-by: Rahul Huilgol <[email protected]>
* change where collection manager is retrieved
* Skip test
* trigger CI
* reduce logging
* dont block stdout
1 parent 15ce7f2 commit b4fa5d9


63 files changed (+3196 / -1028 lines changed)

bin/sagemaker-containers/tensorflow/tf-train.py

Lines changed: 9 additions & 7 deletions
@@ -4,7 +4,7 @@
 import tensorflow as tf
 import time
 import uuid
-from tornasole.tensorflow import TornasoleHook, TornasoleOptimizer, SaveConfig
+from tornasole.tensorflow import TornasoleHook, SaveConfig, get_hook

 parser = argparse.ArgumentParser()
 parser.add_argument("--lr", type=float, help="Learning Rate", default=0.001)
@@ -18,6 +18,13 @@
 # running in Tf estimator mode, script need to accept --model_dir parameter
 parser.add_argument("--model_dir", type=str, help="model dir", default=str(uuid.uuid4()))
 args = parser.parse_args()
+
+t = str(time.time())
+hook = TornasoleHook(
+    "s3://tornasolecodebuildtest/container_testing/ts_outputs/tf" + t,
+    save_config=SaveConfig(save_interval=10),
+)
+
 # Network definition
 with tf.name_scope("foobar"):
     x = tf.placeholder(shape=(None, 2), dtype=tf.float32)
@@ -29,15 +36,10 @@
 global_step = tf.Variable(17, name="global_step", trainable=False)
 increment_global_step_op = tf.assign(global_step, global_step + 1)
 optimizer = tf.train.AdamOptimizer(args.lr)
-optimizer = TornasoleOptimizer(optimizer)
+optimizer = get_hook().wrap_optimizer(optimizer)
 optimizer_op = optimizer.minimize(loss, global_step=increment_global_step_op)
 graph = tf.get_default_graph()
 list_of_tuples = [op.outputs for op in graph.get_operations()]
-t = str(time.time())
-hook = TornasoleHook(
-    "s3://tornasolecodebuildtest/container_testing/ts_outputs/tf" + t,
-    save_config=SaveConfig(save_interval=10),
-)
 sess = tf.train.MonitoredSession(hooks=[hook])
 for i in range(args.steps):
     x_ = np.random.random((10, 2)) * args.scale
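The reordering above matters: `get_hook()` can only return a hook that has already been constructed, so the hook creation moves ahead of the optimizer wrapping. A minimal sketch of the resulting flow, assuming the `tornasole.tensorflow` API shown in this diff (the output path and learning rate are placeholders):

```python
# Sketch only: API names (TornasoleHook, SaveConfig, get_hook, wrap_optimizer)
# come from the diff above; the S3 path and learning rate are placeholders.
import tensorflow as tf
from tornasole.tensorflow import TornasoleHook, SaveConfig, get_hook

# 1. Create the hook first; it registers itself so get_hook() can retrieve it.
hook = TornasoleHook("s3://my-bucket/ts_outputs", save_config=SaveConfig(save_interval=10))

# 2. Wrap the optimizer through the registered hook to capture gradients.
optimizer = get_hook().wrap_optimizer(tf.train.AdamOptimizer(0.001))
```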

config/buildspec.yml

Lines changed: 5 additions & 4 deletions
@@ -12,11 +12,12 @@ phases:
     commands:
       - . config/change_branch.sh
       - su && apt-get update
-      - apt-get install sudo
-      - sudo apt-get update
-      - sudo apt-get install unzip
+      - apt-get install sudo -qq -o=Dpkg::Use-Pty=0 # silence output: https://askubuntu.com/a/668859/724247
+      - sudo apt-get update -qq -o=Dpkg::Use-Pty=0
+      - sudo apt-get install unzip -qq -o=Dpkg::Use-Pty=0
       - cd $CODEBUILD_SRC_DIR && chmod +x config/protoc_downloader.sh && ./config/protoc_downloader.sh
-      - pip install pytest wheel pyYaml pytest-html tensorflow==1.14.0 mxnet torch xgboost pre-commit
+      - pip install -U pip
+      - pip install -q pytest wheel pyYaml pytest-html keras==2.3.1 tensorflow==1.15.0 mxnet torch xgboost pre-commit tensorflow_datasets
       - pip uninstall -y boto3 && pip uninstall -y aiobotocore && pip uninstall -y botocore

   pre_build:

config/tests.sh

Lines changed: 2 additions & 2 deletions
@@ -13,12 +13,12 @@ check_logs() {
 }

 run_for_framework() {
-    python -m pytest --html=$REPORT_DIR/report_$1.html --self-contained-html tests/$1
+    python -m pytest --html=$REPORT_DIR/report_$1.html -v -s --self-contained-html tests/$1
     python -m pytest --html=$REPORT_DIR/test_rules_$1.html --self-contained-html -s tests/analysis/integration_testing_rules.py::test_test_rules --mode $1 --path_to_config ./tests/analysis/config.yaml --out_dir $OUT_DIR 2>&1 | tee $REPORT_DIR/test_rules_$1.log
 }

 export TF_CPP_MIN_LOG_LEVEL=1
-export TORNASOLE_LOG_LEVEL=debug
+export TORNASOLE_LOG_LEVEL=info
 #export BLOCK_STDOUT=TRUE
 #export BLOCK_STDERR=FALSE

docs/tensorflow/README.md

Lines changed: 8 additions & 8 deletions
@@ -77,12 +77,12 @@ If you do not specify this, it saves steps under a `GLOBAL` mode.
 ```
 hook.set_mode(ts.modes.TRAIN)
 ```
-Wrap your optimizer with TornasoleOptimizer so that
+Wrap your optimizer with wrap_optimizer so that
 Tornasole can identify your gradients and automatically
 provide these tensors as part of the `gradients` collection.
 Use this new optimizer to minimize the loss.
 ```
-optimizer = ts.TornasoleOptimizer(optimizer)
+optimizer = hook.wrap_optimizer(optimizer)
 ```
 Create a monitored session with the above hook, and use this for executing your TensorFlow job.
 ```
@@ -110,12 +110,12 @@ If you do not specify this, it saves steps under a `GLOBAL` mode.
 ```
 hook.set_mode(ts.modes.TRAIN)
 ```
-Wrap your optimizer with TornasoleOptimizer so that
+Wrap your optimizer with wrap_optimizer so that
 Tornasole can identify your gradients and automatically
 provide these tensors as part of the `gradients` collection.
 Use this new optimizer to minimize the loss.
 ```
-opt = ts.TornasoleOptimizer(opt)
+opt = hook.wrap_optimizer(opt)
 ```
 Now pass this hook to the estimator object's train, predict or evaluate methods, whichever ones you want to monitor.
 ```
@@ -410,13 +410,13 @@ hook = ts.TornasoleHook(..., include_collections = ['weights'], ...)

 #### Gradients
 We provide an easy way to populate the collection named `gradients` with the gradients wrt to the weights.
-This can be done by wrapping around your optimizer with `TornasoleOptimizer` as follows.
+This can be done by wrapping around your optimizer with `wrap_optimizer` as follows.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.

 ```
 import tornasole.tensorflow as ts
 ...
-opt = ts.TornasoleOptimizer(opt)
+opt = hook.wrap_optimizer(opt)
 ```
 An example for this can be seen in [this script](../../examples/tensorflow/scripts/train_imagenet_resnet_hvd.py#L738)
 Alternatively, you can refer to [customize collections](#customizing-collections) for
@@ -449,7 +449,7 @@ hook = ts.TornasoleHook(..., include_collections = ['losses'..], ...)

 #### Optimizer Variables
 Optimizer variables such as momentum can also be saved easily with the
-above approach of wrapping your optimizer with `TornasoleOptimizer`
+above approach of wrapping your optimizer with `wrap_optimizer`
 followed by passing `optimizer_variables` in the `include_collections` parameter of the hook.
 ```
 import tornasole.tensorflow as ts
@@ -461,7 +461,7 @@ Please refer [API](api.md) for more details on using collections
 ### Customizing collections
 You can also create any other customized collection yourself.
 You can create new collections as well as modify existing collections
-(such as including gradients if you do not want to use the above `TornasoleOptimizer`)
+(such as including gradients if you do not want to use the above `wrap_optimizer`)
 #### Creating or accessing a collection
 Each collection should have a unique name (which is a string).
 You can get the collection named as `collection_name` by
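The collection workflow this README describes can be sketched as follows. The `include(regex)` method on a named collection is an assumption, suggested by the commented-out `get_collection('default').include('Relu')` line in the new keras.py example later in this commit; the output path and collection name are placeholders:

```python
# Sketch under assumptions: ts.get_collection(name) returns a named collection
# and include(regex) adds matching tensors to it, per the commented-out
# get_collection('default').include('Relu') line in the new keras.py example.
import tensorflow as tf
import tornasole.tensorflow as ts

ts.get_collection("relu_outputs").include("Relu")  # hypothetical custom collection
hook = ts.TornasoleHook(
    "~/ts_outputs/run1",  # placeholder output location
    include_collections=["weights", "losses", "relu_outputs"],
)
opt = hook.wrap_optimizer(tf.train.AdamOptimizer(0.001))  # also fills `gradients`
```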

docs/tensorflow/examples/distributed_training/horovod_mnist_estimator.md

Lines changed: 2 additions & 2 deletions
@@ -13,10 +13,10 @@ import tornasole.tensorflow as ts
 ```
 **Saving gradients**

-We need to wrap our optimizer with TornasoleOptimizer, and use this optimizer to minimize loss.
+We need to wrap our optimizer with hook.wrap_optimizer, and use this optimizer to minimize loss.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.
 ```
-opt = TornasoleOptimizer(opt)
+opt = hook.wrap_optimizer(opt)
 ```

docs/tensorflow/examples/distributed_training/mirrored_strategy_mnist.md

Lines changed: 3 additions & 3 deletions
@@ -15,10 +15,10 @@ import tornasole.tensorflow as ts
 ```
 **Saving gradients**

-We need to wrap our optimizer with TornasoleOptimizer, and use this optimizer to minimize loss.
+We need to wrap our optimizer with hook.wrap_optimizer, and use this optimizer to minimize loss.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.
 ```
-optimizer = ts.TornasoleOptimizer(optimizer)
+optimizer = hook.wrap_optimizer(optimizer)
 ```

@@ -72,7 +72,7 @@ python mirrored_strategy_mnist.py \
 --tornasole_path ~/ts_outputs/mirrored_strategy_mnist \
 --steps 5000\
 --tornasole_frequency 100\
---reductions False
+--reductions False\
 --save_all True

 ```

docs/tensorflow/examples/distributed_training/parameter_server_training/parameter_server_mnist.md

Lines changed: 2 additions & 2 deletions
@@ -15,10 +15,10 @@ import tornasole.tensorflow as ts
 ```
 **Saving gradients**

-We need to wrap our optimizer with TornasoleOptimizer, and use this optimizer to minimize loss.
+We need to wrap our optimizer with hook.wrap_optimizer, and use this optimizer to minimize loss.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.
 ```
-optimizer = ts.TornasoleOptimizer(optimizer)
+optimizer = hook.wrap_optimizer(optimizer)
 ```

docs/tensorflow/examples/mnist.md

Lines changed: 3 additions & 3 deletions
@@ -13,11 +13,11 @@ import tornasole.tensorflow as ts
 ```
 **Saving gradients**

-We need to wrap our optimizer with TornasoleOptimizer, and use this optimizer to minimize loss.
+We need to wrap our optimizer with wrap_optimizer, and use this optimizer to minimize loss.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.
 ```
-opt = TornasoleOptimizer(opt)
-optimizer_op = optimizer.minimize(loss, global_step=increment_global_step_op)
+opt = hook.wrap_optimizer(opt)
+optimizer_op = opt.minimize(loss, global_step=increment_global_step_op)
 ```
 Note that here since by default Tornasole tries to save weights, gradients and losses
 we didn't need to specify 'gradients' in the include_collections argument of the hook.

docs/tensorflow/examples/resnet50.md

Lines changed: 3 additions & 3 deletions
@@ -24,13 +24,13 @@ because by default Tornasole tries to save weights, gradients and losses.

 **Saving gradients**

-We need to wrap our optimizer with TornasoleOptimizer, and use this optimizer to minimize loss.
+We need to wrap our optimizer with wrap_optimizer, and use this optimizer to minimize loss.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.
 ```
-opt = TornasoleOptimizer(opt)
-
 include_collections.append('gradients')
 ts.TornasoleHook(..., include_collections=include_collections, ...)
+
+opt = hook.wrap_optimizer(opt)
 ```
 Note that if include_collections is not passed to the hook,
 by default Tornasole tries to save weights, gradients and losses.

docs/tensorflow/examples/simple.md

Lines changed: 4 additions & 4 deletions
@@ -16,13 +16,13 @@ ts.TornasoleHook(..., save_all=True, ...)
 ```
 **Saving gradients**

-We need to wrap our optimizer with TornasoleOptimizer, and use this optimizer to minimize loss.
+We need to wrap our optimizer with wrap_optimizer, and use this optimizer to minimize loss.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.
 ```
-opt = TornasoleOptimizer(opt)
-optimizer_op = optimizer.minimize(loss, global_step=increment_global_step_op)
+hook = ts.TornasoleHook(..., include_collections=[..,'gradients'], ...)
+opt = hook.wrap_optimizer(opt)
+optimizer_op = opt.minimize(loss, global_step=increment_global_step_op)

-ts.TornasoleHook(..., include_collections=[..,'gradients'], ...)
 ```
 **Saving losses**

docs/tensorflow/examples/sm_resnet50.md

Lines changed: 2 additions & 2 deletions
@@ -20,10 +20,10 @@ include_collections.append('weights')
 ```
 **Saving gradients**

-We need to wrap our optimizer with TornasoleOptimizer, and use this optimizer to minimize loss.
+We need to wrap our optimizer with wrap_optimizer, and use this optimizer to minimize loss.
 This will also enable us to access the gradients during analysis without having to identify which tensors out of the saved ones are the gradients.
 ```
-opt = TornasoleOptimizer(opt)
+opt = hook.wrap_optimizer(opt)

 include_collections.append('gradients')
 ts.TornasoleHook(..., include_collections=include_collections, ...)

examples/tensorflow/sagemaker-notebooks/tensorflow.ipynb

Lines changed: 6 additions & 6 deletions
@@ -82,10 +82,10 @@
 "\n",
 "```python\n",
 "import tornasole.tensorflow as ts\n",
-"# Wrap the optimizer with Tornasole optimizer to identify gradients\n",
-"optimizer = ts.TornasoleOptimizer(optimizer) \n",
 "# Ask TORNASOLE to save all tensors. Note: TornasoleHook is highly configurable\n",
 "hook = ts.TornasoleHook(save_all=True) \n",
+"# Wrap the optimizer with Tornasole optimizer to identify gradients\n",
+"optimizer = hook.wrap_optimizer(optimizer) \n",
 "# pass the hook to hooks parameter of monitored session\n",
 "sess = tf.train.MonitoredSession(hooks=[hook])\n",
 "```\n",
@@ -489,9 +489,9 @@
 "```\n",
 "hook.set_mode(ts.modes.TRAIN)\n",
 "```\n",
-"Wrap your optimizer with TornasoleOptimizer so that Tornasole can identify your gradients and automatically provide these tensors as part of the `gradients` collection. Use this new optimizer to minimize your loss during training.\n",
+"Wrap your optimizer with wrap_optimizer so that Tornasole can identify your gradients and automatically provide these tensors as part of the `gradients` collection. Use this new optimizer to minimize your loss during training.\n",
 "```\n",
-"optimizer = ts.TornasoleOptimizer(optimizer)\n",
+"optimizer = hook.wrap_optimizer(optimizer)\n",
 "```\n",
 "Create a monitored session with the above hook, and use this for executing your TensorFlow job.\n",
 "```\n",
@@ -519,9 +519,9 @@
 "```\n",
 "hook.set_mode(ts.modes.TRAIN)\n",
 "```\n",
-"Wrap your optimizer with TornasoleOptimizer so that Tornasole can identify your gradients and automatically provide these tensors as part of the `gradients` collection. Use this new optimizer to minimize your loss during training.\n",
+"Wrap your optimizer with wrap_optimizer so that Tornasole can identify your gradients and automatically provide these tensors as part of the `gradients` collection. Use this new optimizer to minimize your loss during training.\n",
 "```\n",
-"opt = ts.TornasoleOptimizer(opt)\n",
+"opt = hook.wrap_optimizer(opt)\n",
 "```\n",
 "Now pass this hook to the estimator object's train, predict or evaluate methods, whichever ones you want to monitor.\n",
 "```\n",

examples/tensorflow/scripts/distributed_training/horovod_mnist_estimator.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def cnn_model_fn(features, labels, mode):
         optimizer = hvd.DistributedOptimizer(optimizer)

         # Tornasole: add Tornasole Optimizer
-        optimizer = ts.TornasoleOptimizer(optimizer)
+        optimizer = ts.get_hook().wrap_optimizer(optimizer)

         train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
         return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

examples/tensorflow/scripts/distributed_training/mirrored_strategy_mnist.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def cnn_model_fn(features, labels, mode):
     # Configure the Training Op (for TRAIN mode)
     if mode == tf.estimator.ModeKeys.TRAIN:
         optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
-        optimizer = ts.TornasoleOptimizer(optimizer)
+        optimizer = ts.get_hook().wrap_optimizer(optimizer)
         train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
         return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

examples/tensorflow/scripts/distributed_training/parameter_server_training/parameter_server_mnist.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def cnn_model_fn(features, labels, mode):
     # Configure the Training Op (for TRAIN mode)
     if mode == tf.estimator.ModeKeys.TRAIN:
         optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
-        optimizer = ts.TornasoleOptimizer(optimizer)
+        optimizer = ts.get_hook().wrap_optimizer(optimizer)
         train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
         return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
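In these estimator scripts the hook is constructed at module scope, so the model function retrieves it with `ts.get_hook()` rather than receiving it as an argument. A condensed sketch of that pattern, assuming the API from the diffs above (the output directory and model layers are placeholders, not the scripts' actual model):

```python
# Condensed sketch of the estimator-mode pattern from the three scripts above.
# The out_dir is a placeholder; the model is simplified for brevity.
import tensorflow as tf
import tornasole.tensorflow as ts

hook = ts.TornasoleHook("~/ts_outputs/estimator_run")  # created once, up front

def cnn_model_fn(features, labels, mode):
    logits = tf.layers.dense(tf.layers.flatten(features), 10)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    if mode == tf.estimator.ModeKeys.TRAIN:
        # get_hook() retrieves the module-level hook inside the model function.
        optimizer = ts.get_hook().wrap_optimizer(
            tf.train.GradientDescentOptimizer(learning_rate=0.001))
        train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
```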

examples/tensorflow/scripts/keras.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import tensorflow_datasets as tfds
+import tensorflow as tf
+from tornasole.tensorflow import TornasoleKerasHook, get_collection
+from tornasole.core.collection import CollectionKeys
+
+tfds.disable_progress_bar()
+
+
+def train_model():
+    print(tf.__version__)
+
+    datasets, info = tfds.load(name="mnist", with_info=True, as_supervised=True)
+
+    mnist_train, mnist_test = datasets["train"], datasets["test"]
+
+    strategy = tf.distribute.MirroredStrategy()
+
+    # You can also do info.splits.total_num_examples to get the total
+    # number of examples in the dataset.
+
+    num_train_examples = info.splits["train"].num_examples
+    num_test_examples = info.splits["test"].num_examples
+
+    BUFFER_SIZE = 10000
+
+    BATCH_SIZE_PER_REPLICA = 64
+    BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
+
+    def scale(image, label):
+        image = tf.cast(image, tf.float32)
+        image /= 255
+
+        return image, label
+
+    train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
+    eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
+
+    hook = TornasoleKerasHook(
+        out_dir="~/ts_outputs/",
+        include_collections=[
+            # CollectionKeys.WEIGHTS,
+            # CollectionKeys.GRADIENTS,
+            # CollectionKeys.OPTIMIZER_VARIABLES,
+            CollectionKeys.DEFAULT,
+            # CollectionKeys.METRICS,
+            # CollectionKeys.LOSSES,
+            # CollectionKeys.OUTPUTS,
+            # CollectionKeys.SCALARS,
+        ],
+        save_all=True,
+    )
+
+    with strategy.scope():
+        model = tf.keras.Sequential(
+            [
+                tf.keras.layers.Conv2D(32, 3, activation="relu", input_shape=(28, 28, 1)),
+                tf.keras.layers.MaxPooling2D(),
+                tf.keras.layers.Flatten(),
+                tf.keras.layers.Dense(64, activation="relu"),
+                tf.keras.layers.Dense(10, activation="softmax"),
+            ]
+        )
+        model.compile(
+            loss="sparse_categorical_crossentropy",
+            optimizer=hook.wrap_optimizer(tf.keras.optimizers.Adam()),
+            metrics=["accuracy"],
+        )
+
+    # get_collection('default').include('Relu')
+
+    callbacks = [
+        hook
+        # tf.keras.callbacks.TensorBoard(log_dir='./logs'),
+    ]
+
+    model.fit(train_dataset, epochs=1, callbacks=callbacks)
+    model.predict(eval_dataset, callbacks=callbacks)
+    model.fit(train_dataset, epochs=1, callbacks=callbacks)
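Note that the new file defines `train_model()` but never calls it in the hunk shown, so it presumably relies on an external driver such as the commit's test suite. A hypothetical invocation (the module path is an assumption, not part of the commit):

```python
# Hypothetical driver; the import path is an assumption, not part of the commit.
import importlib

keras_example = importlib.import_module("examples.tensorflow.scripts.keras")
keras_example.train_model()  # fit -> predict -> fit with the TornasoleKerasHook attached
```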
