Skip to content

Move python version to TPU config #1304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions tpu/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
FROM python:3.8
ARG BASE_IMAGE

FROM $BASE_IMAGE

# These ARGs must be declared after the FROM statement so that they are accessible within this build stage.
# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
ARG PYTHON_WHEEL_VERSION
ARG PYTHON_VERSION_PATH
ARG LINUX_WHEEL_VERSION
ARG TORCH_VERSION
ARG TENSORFLOW_VERSION
ARG TF_LIBTPU_VERSION
Expand All @@ -17,19 +22,19 @@ ADD patches/template_conf.json /opt/kaggle/conf.json

# Add BigQuery client proxy settings, kaggle secrets etc.
ENV PYTHONUSERBASE "/root/.local"
ADD patches/kaggle_secrets.py /root/.local/lib/python3.8/site-packages/kaggle_secrets.py
ADD patches/kaggle_session.py /root/.local/lib/python3.8/site-packages/kaggle_session.py
ADD patches/kaggle_web_client.py /root/.local/lib/python3.8/site-packages/kaggle_web_client.py
ADD patches/kaggle_datasets.py /root/.local/lib/python3.8/site-packages/kaggle_datasets.py
ADD patches/kaggle_secrets.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_secrets.py
ADD patches/kaggle_session.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_session.py
ADD patches/kaggle_web_client.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_web_client.py
ADD patches/kaggle_datasets.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_datasets.py

# Disable GCP integrations for now
# ADD patches/kaggle_gcp.py /root/.local/lib/python3.8/site-packages/kaggle_gcp.py
# ADD patches/kaggle_gcp.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_gcp.py

# Disable logging to file (why do we need this?)
# ADD patches/log.py /root/.local/lib/python3.8/site-packages/log.py
# ADD patches/log.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/log.py

# sitecustomize adds significant latency to ipython kernel startup and should only be added if needed
# ADD patches/sitecustomize.py /root/.local/lib/python3.8/site-packages/sitecustomize.py
# ADD patches/sitecustomize.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/sitecustomize.py

# Install all the packages together for maximum compatibility.

Expand All @@ -47,8 +52,8 @@ ADD patches/kaggle_datasets.py /root/.local/lib/python3.8/site-packages/kaggle_d

# Additional useful packages should be added here

RUN pip install tensorflow_hub https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl tensorflow-addons tensorflow-probability tensorflow-io \
torch==${TORCH_VERSION} https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION} \
RUN pip install tensorflow_hub https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${LINUX_WHEEL_VERSION}.whl tensorflow-addons tensorflow-probability tensorflow-io \
torch==${TORCH_VERSION} https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${LINUX_WHEEL_VERSION}.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION} \
jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax \
numpy==1.23.5 \
papermill jupyterlab python-lsp-server[all] "jupyter-lsp==1.5.1" \
Expand All @@ -58,12 +63,12 @@ RUN pip install tensorflow_hub https://storage.googleapis.com/cloud-tpu-tpuvm-ar
RUN curl --output /lib/libtpu.so https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/${TF_LIBTPU_VERSION}/libtpu.so

# Kaggle Model Hub patches:
ADD patches/kaggle_module_resolver.py /usr/local/lib/python3.8/site-packages/tensorflow_hub/kaggle_module_resolver.py
RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/python3.8/site-packages/tensorflow_hub/config.py
RUN sed -i '/_install_default_resolvers()/a \ \ registry.resolver.add_implementation(kaggle_module_resolver.KaggleFileResolver())' /usr/local/lib/python3.8/site-packages/tensorflow_hub/config.py
ADD patches/kaggle_module_resolver.py /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/kaggle_module_resolver.py
RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py
RUN sed -i '/_install_default_resolvers()/a \ \ registry.resolver.add_implementation(kaggle_module_resolver.KaggleFileResolver())' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py

# Monkey-patch the default TPU to the local (TPU VM).
RUN sed -i 's/tpu=None,/tpu="local",/' /usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
RUN sed -i 's/tpu=None,/tpu="local",/' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py

# Set these env vars so that loading them doesn't produce errors from calling the metadata server:
ENV TPU_ACCELERATOR_TYPE=v3-8
Expand Down
4 changes: 4 additions & 0 deletions tpu/config.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
BASE_IMAGE=python:3.8
PYTHON_WHEEL_VERSION=cp38
PYTHON_VERSION_PATH=python3.8
# https://cloud.google.com/tpu/docs/supported-tpu-configurations#tpu_software_versions:~:text=TensorFlow%20version-,libtpu.so%20version,-2.13.0
TENSORFLOW_VERSION=2.12.0
TF_LIBTPU_VERSION=1.6.0
Expand All @@ -10,3 +13,4 @@ TORCHAUDIO_VERSION=2.0.0
TORCHTEXT_VERSION=0.15.1
# https://github.com/pytorch/vision supports nightly
TORCHVISION_VERSION=0.15.1
LINUX_WHEEL_VERSION=linux_x86_64