You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
It turns out jax & pytorch are incompatible (require different libtpu
versions). In order to support importing EITHER of them (but not both)
we will swap in the correct libtpu during import (by monkey-patching the
import code for both).
http://b/213335159
RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
33
+
cp $DEFAULT_LIBTPU $JAX_LIBTPU && \
31
34
/tmp/clean-layer.sh
35
+
36
+
# Monkey-patch JAX & PYTORCH to load the correct libtpu.so when they are imported:
37
+
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \
38
+
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py
0 commit comments