Skip to content

Commit f3f21de

Browse files
authored
Merge pull request #1304 from Kaggle/tpu-config
Move python version to TPU config
2 parents 0fac236 + 5723345 commit f3f21de

File tree

2 files changed

+23
-14
lines changed

2 files changed

+23
-14
lines changed

tpu/Dockerfile

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
FROM python:3.8
1+
ARG BASE_IMAGE
2+
3+
FROM $BASE_IMAGE
24

35
# ARGs must be declared after the FROM statement to be accessible within this build stage; an ARG declared before FROM is only visible to FROM itself.
46
# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
7+
ARG PYTHON_WHEEL_VERSION
8+
ARG PYTHON_VERSION_PATH
9+
ARG LINUX_WHEEL_VERSION
510
ARG TORCH_VERSION
611
ARG TENSORFLOW_VERSION
712
ARG TF_LIBTPU_VERSION
@@ -17,19 +22,19 @@ ADD patches/template_conf.json /opt/kaggle/conf.json
1722

1823
# Add BigQuery client proxy settings, kaggle secrets etc.
1924
# Point Python's per-user base at /root/.local (key=value form; the space-separated ENV form is legacy).
ENV PYTHONUSERBASE="/root/.local"
20-
ADD patches/kaggle_secrets.py /root/.local/lib/python3.8/site-packages/kaggle_secrets.py
21-
ADD patches/kaggle_session.py /root/.local/lib/python3.8/site-packages/kaggle_session.py
22-
ADD patches/kaggle_web_client.py /root/.local/lib/python3.8/site-packages/kaggle_web_client.py
23-
ADD patches/kaggle_datasets.py /root/.local/lib/python3.8/site-packages/kaggle_datasets.py
25+
# Plain local file: COPY is sufficient (none of ADD's archive-extraction/URL behavior is needed).
COPY patches/kaggle_secrets.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_secrets.py
26+
# Plain local file: COPY is sufficient (none of ADD's archive-extraction/URL behavior is needed).
COPY patches/kaggle_session.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_session.py
27+
# Plain local file: COPY is sufficient (none of ADD's archive-extraction/URL behavior is needed).
COPY patches/kaggle_web_client.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_web_client.py
28+
# Plain local file: COPY is sufficient (none of ADD's archive-extraction/URL behavior is needed).
COPY patches/kaggle_datasets.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_datasets.py
2429

2530
# Disable GCP integrations for now
26-
# ADD patches/kaggle_gcp.py /root/.local/lib/python3.8/site-packages/kaggle_gcp.py
31+
# ADD patches/kaggle_gcp.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/kaggle_gcp.py
2732

2833
# Disable logging to file (why do we need this?)
29-
# ADD patches/log.py /root/.local/lib/python3.8/site-packages/log.py
34+
# ADD patches/log.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/log.py
3035

3136
# sitecustomize adds significant latency to ipython kernel startup and should only be added if needed
32-
# ADD patches/sitecustomize.py /root/.local/lib/python3.8/site-packages/sitecustomize.py
37+
# ADD patches/sitecustomize.py /root/.local/lib/${PYTHON_VERSION_PATH}/site-packages/sitecustomize.py
3338

3439
# Install all the packages together for maximum compatibility.
3540

@@ -47,8 +52,8 @@ ADD patches/kaggle_datasets.py /root/.local/lib/python3.8/site-packages/kaggle_d
4752

4853
# Additional useful packages should be added here
4954

50-
RUN pip install tensorflow_hub https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl tensorflow-addons tensorflow-probability tensorflow-io \
51-
torch==${TORCH_VERSION} https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION} \
55+
RUN pip install tensorflow_hub https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${LINUX_WHEEL_VERSION}.whl tensorflow-addons tensorflow-probability tensorflow-io \
56+
torch==${TORCH_VERSION} https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${LINUX_WHEEL_VERSION}.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION} \
5257
jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax \
5358
numpy==1.23.5 \
5459
papermill jupyterlab python-lsp-server[all] "jupyter-lsp==1.5.1" \
@@ -58,12 +63,12 @@ RUN pip install tensorflow_hub https://storage.googleapis.com/cloud-tpu-tpuvm-ar
5863
# Download libtpu.so for the configured TF_LIBTPU_VERSION.
# --fail makes curl exit non-zero on an HTTP error (e.g. a version that does not exist)
# instead of saving the server's error page as /lib/libtpu.so and letting the build continue.
RUN curl --fail --output /lib/libtpu.so https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/${TF_LIBTPU_VERSION}/libtpu.so
5964

6065
# Kaggle Model Hub patches:
61-
ADD patches/kaggle_module_resolver.py /usr/local/lib/python3.8/site-packages/tensorflow_hub/kaggle_module_resolver.py
62-
RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/python3.8/site-packages/tensorflow_hub/config.py
63-
RUN sed -i '/_install_default_resolvers()/a \ \ registry.resolver.add_implementation(kaggle_module_resolver.KaggleFileResolver())' /usr/local/lib/python3.8/site-packages/tensorflow_hub/config.py
66+
# Drop the Kaggle resolver into the installed tensorflow_hub package.
# Plain local file: COPY is sufficient (none of ADD's archive-extraction/URL behavior is needed).
COPY patches/kaggle_module_resolver.py /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/kaggle_module_resolver.py
67+
RUN sed -i '/from tensorflow_hub import uncompressed_module_resolver/a from tensorflow_hub import kaggle_module_resolver' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py
68+
RUN sed -i '/_install_default_resolvers()/a \ \ registry.resolver.add_implementation(kaggle_module_resolver.KaggleFileResolver())' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow_hub/config.py
6469

6570
# Monkey-patch the default TPU to the local (TPU VM).
66-
RUN sed -i 's/tpu=None,/tpu="local",/' /usr/local/lib/python3.8/site-packages/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
71+
RUN sed -i 's/tpu=None,/tpu="local",/' /usr/local/lib/${PYTHON_VERSION_PATH}/site-packages/tensorflow/python/distribute/cluster_resolver/tpu/tpu_cluster_resolver.py
6772

6873
# Set these env vars explicitly so that nothing errors out calling the metadata server to load them:
6974
ENV TPU_ACCELERATOR_TYPE=v3-8

tpu/config.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
BASE_IMAGE=python:3.8
2+
PYTHON_WHEEL_VERSION=cp38
3+
PYTHON_VERSION_PATH=python3.8
14
# https://cloud.google.com/tpu/docs/supported-tpu-configurations#tpu_software_versions:~:text=TensorFlow%20version-,libtpu.so%20version,-2.13.0
25
TENSORFLOW_VERSION=2.12.0
36
TF_LIBTPU_VERSION=1.6.0
@@ -10,3 +13,4 @@ TORCHAUDIO_VERSION=2.0.0
1013
TORCHTEXT_VERSION=0.15.1
1114
# https://github.com/pytorch/vision supports nightly
1215
TORCHVISION_VERSION=0.15.1
16+
LINUX_WHEEL_VERSION=linux_x86_64

0 commit comments

Comments
 (0)