Skip to content

Load correct libtpu for pytorch & jax. #1166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions tpu/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
ARG BASE_IMAGE_TAG
ARG LIBTPU_IMAGE_TAG
ARG TENSORFLOW_VERSION

FROM gcr.io/cloud-tpu-v2-images/libtpu:${LIBTPU_IMAGE_TAG} as libtpu
FROM gcr.io/kaggle-images/python-tpu-tensorflow-whl:python-${BASE_IMAGE_TAG}-${TENSORFLOW_VERSION} AS tensorflow_whl
FROM gcr.io/kaggle-images/python:${BASE_IMAGE_TAG}

Expand All @@ -12,20 +10,34 @@ ARG TORCH_VERSION

ENV ISTPUVM=1

COPY --from=libtpu /libtpu.so /lib

COPY --from=tensorflow_whl /tmp/tensorflow_pkg/tensorflow*.whl /tmp/tensorflow_pkg/
RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \
rm -rf /tmp/tensorflow_pkg && \
/tmp/clean-layer.sh

# LIBTPU installed here:
ENV DEFAULT_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/libtpu.so
ENV PYTORCH_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/torch-libtpu.so
ENV JAX_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/jax-libtpu.so

# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
RUN pip uninstall -y torch && \
pip install torch==${TORCH_VERSION} && \
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp37-cp37m-linux_x86_64.whl && \
cp $DEFAULT_LIBTPU $PYTORCH_LIBTPU && \
/tmp/clean-layer.sh

# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
cp $DEFAULT_LIBTPU $JAX_LIBTPU && \
/tmp/clean-layer.sh

# Monkey-patch TF, JAX & PYTORCH to load the correct libtpu.so when they are imported:
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py && \
sed -i "1s/^/from jax._src.cloud_tpu_init import cloud_tpu_init\ncloud_tpu_init()\n/" /opt/conda/lib/python3.7/site-packages/tensorflow/__init__.py

# Set these env vars so that they don't produce errs calling the metadata server to load them:
ENV TPU_ACCELERATOR_TYPE=v3-8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the session manager set these instead?

This image could eventually be used on TPU v4 or TPU v3 pods (e.g. v3-XYZ).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll stick them in the session-manager as well (which will override the image defaults), but having the defaults helped me release the image quicker.

I doubt we'll be jumping TPU arch's anytime soon though, but I agree with the idea.

ENV TPU_PROCESS_ADDRESSES=local
3 changes: 1 addition & 2 deletions tpu/config.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# TODO(b/213335159): Use ci-pretest for BASE_IMAGE_TAG once stable.
BASE_IMAGE_TAG=v108
LIBTPU_IMAGE_TAG=libtpu_1.1.0_RC00
BASE_IMAGE_TAG=v115
TENSORFLOW_VERSION=2.8.0
TORCH_VERSION=1.11.0