Skip to content

Commit 8cbe4e0

Browse files
authored
add patch for tf-df pt1
1 parent 71487b6 commit 8cbe4e0

File tree

1 file changed

+51
-6
lines changed

1 file changed

+51
-6
lines changed

Dockerfile.tmpl

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,63 @@ RUN apt-get install -y default-jre && \
202202

203203
RUN pip install -f http://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o && /tmp/clean-layer.sh
204204

205+
# b/318672158 Use simply tensorflow-probability once > 0.23.0 is released.
205206
RUN pip install \
206207
"tensorflow==${TENSORFLOW_VERSION}" \
207208
"tensorflow-io==${TENSORFLOW_IO_VERSION}" \
208-
tensorflow_decision_forests \
209+
git+https://github.com/tensorflow/probability.git@fbc5ebe9b1d343113fb917010096cfd88b32eecf \
209210
tensorflow_text \
210-
tensorflowjs \
211-
tensorflow_hub && \
211+
"tensorflow_hub>=0.16.0" \
212+
tf-keras && \
212213
/tmp/clean-layer.sh
213214

214-
# TODO(b/318672158): Upgrade to Keras 3 once compatible with other TF libries.
215-
# See blockers here: https://b.corp.google.com/issues/319722433#comment8
216-
RUN pip install keras keras-cv keras-nlp && \
215+
# b/318672158 Use simply tensorflow_decision_forests on next release, expected with tf 2.16
216+
RUN pip install tensorflow_decision_forests --no-deps && \
217+
/tmp/clean-layer.sh
218+
219+
RUN sed -i "/import tensorflow_decision_forests as tfdf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/__init__.py && \
220+
sed -i -e "/import tensorflow as tf/a import tf_keras" \
221+
-e "/from yggdrasil_decision_forests.utils.distribute.implementations.grpc/a from tensorflow_decision_forests.keras import keras_internal" \
222+
-e '/try:/{:a;N;/backend = tf.keras.backend/!ba;d}'\
223+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core.py && \
224+
sed -i -e "s/from typing import Optional, List, Dict, Any, Union, NamedTuple/from typing import Any, Dict, List, NamedTuple, Optional, Union/g" \
225+
-e "/import tensorflow as tf/a from tensorflow_decision_forests.keras import keras_internal" \
226+
-e "/import tensorflow as tf/a import tf_keras" \
227+
-e '/layers = tf.keras.layers/{:a;N;/backend = tf.keras.backend/!ba;d}' \
228+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core_inference.py && \
229+
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests -type f -exec sed -i \
230+
-e "s/get_data_handler/keras_internal.get_data_handler/g" \
231+
-e 's/"models.Functional"/keras_internal.Functional/g' \
232+
-e "s/tf.keras.utils.unpack_x_y_sample_weight/keras_internal.unpack_x_y_sample_weight/g" \
233+
-e "s/tf.keras.utils.experimental/keras_internal/g" \
234+
{} \; && \
235+
sed -i -e "/import tensorflow as tf/a import tf_keras" \
236+
-e "/from tensorflow_decision_forests.keras import core/a from tensorflow_decision_forests.keras import keras_internal" \
237+
-e '/layers = tf.keras.layers/{:a;N;/callbacks = tf.keras.callbacks/!ba;d}' \
238+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_test.py && \
239+
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras -type f -exec sed -i \
240+
-e "s/ layers.Input/ tf_keras.layers.Input/g" \
241+
-e "s/layers.minimum/tf_keras.layers.minimum/g" \
242+
-e "s/layers.Concatenate/tf_keras.layers.Concatenate/g" \
243+
-e "s/layers.Dense/tf_keras.layers.Dense/g" \
244+
-e "s/layers.experimental.preprocessing./tf_keras.layers./g" \
245+
-e "s/layers.DenseFeatures/keras_internal.layers.DenseFeatures/g" \
246+
-e "s/models.Model/tf_keras.models.Model/g" {} \; && \
247+
sed -i "s/ models.load_model/ tf_keras.models.load_model/g" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_test.py && \
248+
sed -i "/import tensorflow as tf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/test_runner.py && \
249+
sed -i "/import tensorflow as tf/a import tf_keras" /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/wrappers.py && \
250+
sed -i -e "/import tensorflow as tf/a import tf_keras" \
251+
-e "s/optimizer=optimizers.Adam()/optimizer=tf_keras.optimizers.Adam()/g" \
252+
/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/wrappers_pre_generated.py && \
253+
find /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests -type f -exec sed -i "s/tf.keras./tf_keras./g" {} \;
254+
255+
ADD patches/keras_internal.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal.py
256+
ADD patches/keras_internal_test.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal_test.py
257+
258+
# Remove "--no-deps" flag and "namex" package once Keras 3.* is included in our base image.
259+
# We ignore dependencies since tf2.15 and Keras 3.* should work despite pip saying it won't.
260+
# Currently, keras tries to install a nightly version of tf 2.16: https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/requirements.txt#L2
261+
RUN pip install --no-deps "keras>3" keras-cv keras-nlp namex && \
217262
/tmp/clean-layer.sh
218263

219264
RUN pip install pysal

0 commit comments

Comments
 (0)