Skip to content

Commit b7d1851

Browse files
authored
Pin tensorflow_probability to work with TF 2.4.1 (#1057)
* Pin tensorflow_probability to work with TF 2.4.1 The latest version of `tensorflow_probability` is not compatible with our version of TF: ``` ImportError: This version of TensorFlow Probability requires TensorFlow version >= 2.5; Detected an installation of version 2.4.1. Please upgrade TensorFlow to proceed``` Pinning & adding a test to prevent regression. http://b/194837139 * import tensorflow
1 parent 20bce43 commit b7d1851

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ RUN pip install seaborn python-dateutil dask python-igraph && \
6262
RUN pip install tensorflow==${TENSORFLOW_VERSION} && \
6363
pip install tensorflow-gcs-config==2.4.0 && \
6464
pip install tensorflow-addons==0.12.1 && \
65+
pip install tensorflow_probability==0.12.2 && \
6566
/tmp/clean-layer.sh
6667

6768
RUN apt-get install -y libfreetype6-dev && \

tests/test_tensorflow_probability.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import unittest
2+
3+
# b/194837139 importing tensorflow before tfp was trigerring an error. Adding this import to prevent regression.
4+
import tensorflow
5+
import tensorflow_probability as tfp
6+
7+
8+
class TestTensorFlowProbability(unittest.TestCase):
9+
def test_distribution(self):
10+
tfd = tfp.distributions
11+
dist = tfd.Bernoulli(logits=[-1, 1, 1])
12+
self.assertEqual('Bernoulli', dist.name)

0 commit comments

Comments
 (0)