Skip to content

Commit 05c135e

Browse files
committed
monkeypatch tf and add env vars to suppress warns
1 parent a54fc96 commit 05c135e

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tpu/Dockerfile

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-release
3333
cp $DEFAULT_LIBTPU $JAX_LIBTPU && \
3434
/tmp/clean-layer.sh
3535

36-
# Monkey-patch JAX & PYTORCH to load the correct libtpu.so when they are imported:
36+
# Monkey-patch TF, JAX & PYTORCH to load the correct libtpu.so when they are imported:
3737
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \
38-
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py
38+
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py && \
39+
sed -i "1s/^/from jax._src.cloud_tpu_init import cloud_tpu_init\ncloud_tpu_init()\n/" /opt/conda/lib/python3.7/site-packages/tensorflow/__init__.py
40+
41+
# Set these env vars so that they don't produce errs calling the metadata server to load them:
42+
ENV TPU_ACCELERATOR_TYPE=v3-8
43+
ENV TPU_PROCESS_ADDRESSES=local

0 commit comments

Comments
 (0)