Skip to content

Commit 3cba809

Browse files
committed
bump tpu jax/torch versions, add tpu-info
1 parent 9de1142 commit 3cba809

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

tpu/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ ARG TORCH_LINUX_WHEEL_VERSION
1111
ARG TORCH_VERSION
1212
ARG TENSORFLOW_VERSION
1313
ARG TF_LIBTPU_VERSION
14-
ARG JAX_VERSION
1514
ARG TORCHVISION_VERSION
1615
ARG TORCHAUDIO_VERSION
1716

@@ -67,6 +66,7 @@ RUN envsubst < /kaggle_requirements.txt > /requirements.txt
6766
RUN curl -LsSf https://astral.sh/uv/install.sh | sh
6867
RUN export PATH="${HOME}/.local/bin:${PATH}" && uv pip install --system -r /requirements.txt --prerelease=allow --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
6968
/tmp/clean-layer.sh
69+
ENV PATH="${HOME}/.local/bin:${PATH}"
7070

7171
# Tensorflow libtpu:
7272
RUN curl --output /usr/local/lib/python3.10/site-packages/libtpu/libtpu.so https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/${TF_LIBTPU_VERSION}/libtpu.so

tpu/config.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@ PYTHON_VERSION_PATH=python3.10
66
TENSORFLOW_VERSION=2.16.1
77
TF_LIBTPU_VERSION=1.10.1
88
TF_LINUX_WHEEL_VERSION=manylinux_2_17_x86_64.manylinux2014_x86_64
9-
JAX_VERSION=0.4.23
10-
# gsutil ls gs://pytorch-xla-releases/wheels/tpuvm/* | grep libtpu | grep -v -E ".*rc[0-9].*"
9+
# gsutil ls gs://pytorch-xla-releases/wheels/tpuvm/* | grep libtpu | grep torch_xla | grep -v -E ".*rc[0-9].*" | sed 's/.*torch_xla-\(.*\)+libtpu.*/\1/' | sort -rV
1110
# Supports nightly
12-
TORCH_VERSION=2.4.0
11+
TORCH_VERSION=2.5.0
1312
# https://github.com/pytorch/audio supports nightly
14-
TORCHAUDIO_VERSION=2.4.0
13+
TORCHAUDIO_VERSION=2.5.0
1514
# https://github.com/pytorch/vision supports nightly
16-
TORCHVISION_VERSION=0.19.0
15+
TORCHVISION_VERSION=0.20.0
1716
TORCH_LINUX_WHEEL_VERSION=manylinux_2_28_x86_64

tpu/requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# TPU Utils
2+
tpu-info
13
# Tensorflow packages
24
https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TF_LINUX_WHEEL_VERSION}.whl
35
tensorflow_hub
@@ -9,7 +11,7 @@ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TOR
911
torchaudio==${TORCHAUDIO_VERSION}
1012
torchvision==${TORCHVISION_VERSION}
1113
# Jax packages
12-
jax[tpu]==${JAX_VERSION}
14+
jax[tpu]>=0.4.34
1315
distrax
1416
flax
1517
git+https://github.com/deepmind/dm-haiku

0 commit comments

Comments
 (0)