Skip to content

Commit b6c9ef1

Browse files
authored
Merge pull request #1166 from Kaggle/fix-libtpu
Load correct libtpu for pytorch & jax.
2 parents 85f7616 + 05c135e commit b6c9ef1

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

tpu/Dockerfile

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
ARG BASE_IMAGE_TAG
2-
ARG LIBTPU_IMAGE_TAG
32
ARG TENSORFLOW_VERSION
43

5-
FROM gcr.io/cloud-tpu-v2-images/libtpu:${LIBTPU_IMAGE_TAG} as libtpu
64
FROM gcr.io/kaggle-images/python-tpu-tensorflow-whl:python-${BASE_IMAGE_TAG}-${TENSORFLOW_VERSION} AS tensorflow_whl
75
FROM gcr.io/kaggle-images/python:${BASE_IMAGE_TAG}
86

@@ -12,20 +10,34 @@ ARG TORCH_VERSION
1210

1311
ENV ISTPUVM=1
1412

15-
COPY --from=libtpu /libtpu.so /lib
16-
1713
COPY --from=tensorflow_whl /tmp/tensorflow_pkg/tensorflow*.whl /tmp/tensorflow_pkg/
1814
RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \
1915
rm -rf /tmp/tensorflow_pkg && \
2016
/tmp/clean-layer.sh
2117

18+
# LIBTPU installed here:
19+
ENV DEFAULT_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/libtpu.so
20+
ENV PYTORCH_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/torch-libtpu.so
21+
ENV JAX_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/jax-libtpu.so
22+
2223
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
2324
RUN pip uninstall -y torch && \
2425
pip install torch==${TORCH_VERSION} && \
2526
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
2627
pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp37-cp37m-linux_x86_64.whl && \
28+
cp $DEFAULT_LIBTPU $PYTORCH_LIBTPU && \
2729
/tmp/clean-layer.sh
2830

2931
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
3032
RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
33+
cp $DEFAULT_LIBTPU $JAX_LIBTPU && \
3134
/tmp/clean-layer.sh
35+
36+
# Monkey-patch TF, JAX & PYTORCH to load the correct libtpu.so when they are imported:
37+
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \
38+
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py && \
39+
sed -i "1s/^/from jax._src.cloud_tpu_init import cloud_tpu_init\ncloud_tpu_init()\n/" /opt/conda/lib/python3.7/site-packages/tensorflow/__init__.py
40+
41+
# Set these env vars so that they don't produce errs calling the metadata server to load them:
42+
ENV TPU_ACCELERATOR_TYPE=v3-8
43+
ENV TPU_PROCESS_ADDRESSES=local

tpu/config.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# TODO(b/213335159): Use ci-pretest for BASE_IMAGE_TAG once stable.
2-
BASE_IMAGE_TAG=v108
3-
LIBTPU_IMAGE_TAG=libtpu_1.1.0_RC00
2+
BASE_IMAGE_TAG=v115
43
TENSORFLOW_VERSION=2.8.0
54
TORCH_VERSION=1.11.0

0 commit comments

Comments
 (0)