Skip to content

Commit 774f8dc

Browse files
authored
TPU VM prevent tensorflow/numpy downgrades
1 parent 210373d commit 774f8dc

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tpu/Dockerfile

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ ENV PYTORCH_LIBTPU=/lib/torch-libtpu.so
4646
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
4747
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
4848
# We need to keep the numpy version the same as the installed tf one but compatible with other installs.
49-
RUN pip install numpy==1.24.2 torch==${TORCH_VERSION} torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION}
49+
RUN pip install torch==${TORCH_VERSION} torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION} \
50+
https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl \
51+
numpy==1.23.5
5052

5153
RUN cp $PIP_LIBTPU $PYTORCH_LIBTPU
5254

@@ -56,7 +58,9 @@ RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_L
5658
# Install JAX & related packages
5759
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
5860
# We need to keep the numpy version the same as the installed tf one but compatible with other installs.
59-
RUN pip install numpy==1.24.2 jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax
61+
RUN pip install jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax elegy git+https://github.com/deepmind/dm-haiku jraph distrax \
62+
https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl \
63+
numpy==1.23.5
6064

6165
# Packages needed by the Notebook editor:
6266
RUN pip install papermill jupyterlab python-lsp-server[all] jupyterlab-lsp

0 commit comments

Comments
 (0)