Skip to content

Commit d7d0794

Browse files
authored
Move Tensorflow to a wheel, install ecosystem packages
Instead of installing tensorflow directly, which is incompatible with the other packages, we'll simply provide the tensorflow TPU wheel in the image. Also add ecosystem packages for Pytorch & JAX, and add pandas (we can add more popular packages as needed). http://b/213335159
1 parent 0528d46 commit d7d0794

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

tpu/Dockerfile

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,28 @@ ENV ISTPUVM=1
1010
ADD patches/nbconvert-extensions.tpl /opt/kaggle/nbconvert-extensions.tpl
1111
ADD patches/template_conf.json /opt/kaggle/conf.json
1212

13-
# Tensorflow install:
14-
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl && \
15-
curl --output /lib/libtpu.so https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.3.0/libtpu.so
13+
# Tensorflow wheel:
14+
# When tensorflow is compatible with being installed alongside JAX/Pytorch then we no longer need to include the wheel and can install it directly.
15+
# RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl
16+
RUN mkdir -p /lib/wheels && curl --output /lib/wheels/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl
17+
RUN curl --output /lib/libtpu.so https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.3.0/libtpu.so
1618

1719
# LIBTPU installed here:
1820
ENV PIP_LIBTPU=/usr/local/lib/python3.8/site-packages/libtpu/libtpu.so
1921
ENV DEFAULT_LIBTPU=/lib/libtpu.so
2022
ENV PYTORCH_LIBTPU=/lib/torch-libtpu.so
2123
ENV JAX_LIBTPU=/lib/jax-libtpu.so
2224

25+
# Install JAX & related packages
2326
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
24-
RUN pip install "jax[tpu]==0.3.10" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
27+
RUN pip install "jax[tpu]==0.3.10" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax
2528

2629
RUN cp $PIP_LIBTPU $JAX_LIBTPU
2730

31+
# Install Pytorch & related packages
2832
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
29-
RUN pip install torch==${TORCH_VERSION}
30-
3133
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
32-
RUN pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl
34+
RUN pip install torch==${TORCH_VERSION} torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==0.12.0 torchtext==0.12.0 torchaudio==0.11.0
3335

3436
RUN cp $PIP_LIBTPU $PYTORCH_LIBTPU
3537

@@ -40,6 +42,9 @@ RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_L
4042
# Packages needed by the Notebook editor:
4143
RUN pip install papermill jupyterlab python-lsp-server[all] jupyterlab-lsp
4244

45+
# Additional useful packages should be added here:
46+
RUN pip install pandas
47+
4348
# Set these env vars so that they don't produce errs calling the metadata server to load them:
4449
ENV TPU_ACCELERATOR_TYPE=v3-8
4550
ENV TPU_PROCESS_ADDRESSES=local

0 commit comments

Comments
 (0)