Skip to content

Commit 9c637d1

Browse files
authored
Merge pull request #1363 from Kaggle/upgrade-keras-3
Upgrade keras 3
2 parents 71487b6 + 8bcef0f commit 9c637d1

File tree

5 files changed

+107
-24
lines changed

5 files changed

+107
-24
lines changed

Dockerfile.tmpl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ ENV KMP_SETTINGS=false
6262
ENV PIP_ROOT_USER_ACTION=ignore
6363

6464
ADD clean-layer.sh /tmp/clean-layer.sh
65+
ADD patches/keras_patch.sh /tmp/keras_patch.sh
6566
ADD patches/nbconvert-extensions.tpl /opt/kaggle/nbconvert-extensions.tpl
6667
ADD patches/template_conf.json /opt/kaggle/conf.json
6768

@@ -202,18 +203,30 @@ RUN apt-get install -y default-jre && \
202203

203204
RUN pip install -f http://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o && /tmp/clean-layer.sh
204205

206+
# b/318672158 Use simply tensorflow-probability once > 0.23.0 is released.
205207
RUN pip install \
206208
"tensorflow==${TENSORFLOW_VERSION}" \
207209
"tensorflow-io==${TENSORFLOW_IO_VERSION}" \
208-
tensorflow_decision_forests \
210+
git+https://github.com/tensorflow/probability.git@fbc5ebe9b1d343113fb917010096cfd88b32eecf \
209211
tensorflow_text \
210-
tensorflowjs \
211-
tensorflow_hub && \
212+
"tensorflow_hub>=0.16.0" \
213+
tf-keras && \
212214
/tmp/clean-layer.sh
213215

214-
# TODO(b/318672158): Upgrade to Keras 3 once compatible with other TF libries.
215-
# See blockers here: https://b.corp.google.com/issues/319722433#comment8
216-
RUN pip install keras keras-cv keras-nlp && \
216+
# b/318672158 Use simply tensorflow_decision_forests on next release, expected with tf 2.16
217+
RUN pip install tensorflow_decision_forests --no-deps && \
218+
/tmp/clean-layer.sh
219+
220+
RUN chmod +x /tmp/keras_patch.sh && \
221+
/tmp/keras_patch.sh
222+
223+
ADD patches/keras_internal.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal.py
224+
ADD patches/keras_internal_test.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal_test.py
225+
226+
# Remove "--no-deps" flag and "namex" package once Keras 3.* is included in our base image.
227+
# We ignore dependencies since tf2.15 and Keras 3.* should work despite pip saying it won't.
228+
# Currently, keras tries to install a nightly version of tf 2.16: https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/requirements.txt#L2
229+
RUN pip install --no-deps "keras>3" keras-cv keras-nlp namex && \
217230
/tmp/clean-layer.sh
218231

219232
RUN pip install pysal

patches/keras_internal.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2021 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Access to Keras function with a different internal and external path."""
16+
17+
from tf_keras.src.engine import data_adapter as _data_adapter
18+
from tf_keras.src.models import Functional
19+
from tf_keras.layers import DenseFeatures
20+
from tf_keras.src.utils.dataset_creator import DatasetCreator
21+
22+
23+
unpack_x_y_sample_weight = _data_adapter.unpack_x_y_sample_weight
24+
get_data_handler = _data_adapter.get_data_handler

patches/keras_internal_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2021 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import tensorflow as tf
16+
from tensorflow_decision_forests.keras import keras_internal
17+
18+
19+
# Does nothing. Ensures keras_internal can be loaded.
20+
21+
if __name__ == "__main__":
22+
tf.test.main()
23+

patches/keras_patch.sh

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#!/bin/bash
2+
3+
# The following "sed" are to patch the current version of tf-df with
4+
# a fix for keras 3. In essence, replaces the use of package name "tf.keras" with
5+
# "tf_keras"
6+
7+
sed -i "/import tensorflow_decision_forests as tfdf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/__init__.py && \
8+
sed -i -e "/import tensorflow as tf/a import tf_keras" \
9+
-e "/from yggdrasil_decision_forests.utils.distribute.implementations.grpc/a from tensorflow_decision_forests.keras import keras_internal" \
10+
-e '/try:/{:a;N;/backend = tf.keras.backend/!ba;d}'\
11+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core.py && \
12+
sed -i -e "s/from typing import Optional, List, Dict, Any, Union, NamedTuple/from typing import Any, Dict, List, NamedTuple, Optional, Union/g" \
13+
-e "/import tensorflow as tf/a from tensorflow_decision_forests.keras import keras_internal" \
14+
-e "/import tensorflow as tf/a import tf_keras" \
15+
-e '/layers = tf.keras.layers/{:a;N;/backend = tf.keras.backend/!ba;d}' \
16+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core_inference.py && \
17+
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests -type f -exec sed -i \
18+
-e "s/get_data_handler/keras_internal.get_data_handler/g" \
19+
-e 's/"models.Functional"/keras_internal.Functional/g' \
20+
-e "s/tf.keras.utils.unpack_x_y_sample_weight/keras_internal.unpack_x_y_sample_weight/g" \
21+
-e "s/tf.keras.utils.experimental/keras_internal/g" \
22+
{} \; && \
23+
sed -i -e "/import tensorflow as tf/a import tf_keras" \
24+
-e "/from tensorflow_decision_forests.keras import core/a from tensorflow_decision_forests.keras import keras_internal" \
25+
-e '/layers = tf.keras.layers/{:a;N;/callbacks = tf.keras.callbacks/!ba;d}' \
26+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_test.py && \
27+
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras -type f -exec sed -i \
28+
-e "s/ layers.Input/ tf_keras.layers.Input/g" \
29+
-e "s/layers.minimum/tf_keras.layers.minimum/g" \
30+
-e "s/layers.Concatenate/tf_keras.layers.Concatenate/g" \
31+
-e "s/layers.Dense/tf_keras.layers.Dense/g" \
32+
-e "s/layers.experimental.preprocessing./tf_keras.layers./g" \
33+
-e "s/layers.DenseFeatures/keras_internal.layers.DenseFeatures/g" \
34+
-e "s/models.Model/tf_keras.models.Model/g" {} \; && \
35+
sed -i "s/ models.load_model/ tf_keras.models.load_model/g" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_test.py && \
36+
sed -i "/import tensorflow as tf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/test_runner.py && \
37+
sed -i "/import tensorflow as tf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/wrappers.py && \
38+
sed -i -e "/import tensorflow as tf/a import tf_keras" \
39+
-e "s/optimizer=optimizers.Adam()/optimizer=tf_keras.optimizers.Adam()/g" \
40+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/wrappers_pre_generated.py && \
41+
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests -type f -exec sed -i "s/tf.keras./tf_keras./g" {} \;

tests/test_tensorflowjs.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

0 commit comments

Comments
 (0)