Skip to content

Commit 9587f69

Browse files
authored
TF 2.16 with Torch 2.4.0 (#1415)
Upgraded our base cpu and gpu base image -updated torch eco. -removed patch and pins placed for tf 2.15 -removed torch text, it is no longer maintained and incompatible with torch 2.4.0 -geopanda depreciated it's datasets methods, test needed to be updated
1 parent 3a9e7ed commit 9587f69

File tree

10 files changed

+29
-118
lines changed

10 files changed

+29
-118
lines changed

Dockerfile.tmpl

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ ARG GPU_BASE_IMAGE_NAME
55
ARG LIGHTGBM_VERSION
66
ARG TORCH_VERSION
77
ARG TORCHAUDIO_VERSION
8-
ARG TORCHTEXT_VERSION
98
ARG TORCHVISION_VERSION
109
ARG JAX_VERSION
1110

@@ -38,16 +37,15 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib
3837
{{ end }}
3938

4039
# Keep these variables in sync if base image is updated.
41-
ENV TENSORFLOW_VERSION=2.15.0
40+
ENV TENSORFLOW_VERSION=2.16.1
4241
# See https://github.com/tensorflow/io#tensorflow-version-compatibility
43-
ENV TENSORFLOW_IO_VERSION=0.35.0
42+
ENV TENSORFLOW_IO_VERSION=0.37.0
4443

4544
# We need to redefine the ARG here to get the ARG value defined above the FROM instruction.
4645
# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
4746
ARG LIGHTGBM_VERSION
4847
ARG TORCH_VERSION
4948
ARG TORCHAUDIO_VERSION
50-
ARG TORCHTEXT_VERSION
5149
ARG TORCHVISION_VERSION
5250
ARG JAX_VERSION
5351

@@ -62,7 +60,6 @@ ENV KMP_SETTINGS=false
6260
ENV PIP_ROOT_USER_ACTION=ignore
6361

6462
ADD clean-layer.sh /tmp/clean-layer.sh
65-
ADD patches/keras_patch.sh /tmp/keras_patch.sh
6663
ADD patches/nbconvert-extensions.tpl /opt/kaggle/nbconvert-extensions.tpl
6764
ADD patches/template_conf.json /opt/kaggle/conf.json
6865

@@ -122,21 +119,20 @@ RUN pip install spacy && \
122119
{{ end}}
123120

124121
# Install PyTorch
122+
# b/356397043: magma-cuda121 is the latest version
125123
{{ if eq .Accelerator "gpu" }}
126124
COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
127-
RUN mamba install -y -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION} && \
125+
RUN mamba install -y -c pytorch magma-cuda121 && \
128126
pip install /tmp/torch/*.whl && \
129-
# b/255757999 openmp (libomp.so) is an dependency of libtorchtext and libtorchaudio but
130-
mamba install -y openmp && \
127+
sudo apt -y install libsox-dev && \
131128
rm -rf /tmp/torch && \
132129
/tmp/clean-layer.sh
133130
{{ else }}
134131
RUN pip install \
135132
torch==$TORCH_VERSION+cpu \
136133
torchvision==$TORCHVISION_VERSION+cpu \
137134
torchaudio==$TORCHAUDIO_VERSION+cpu \
138-
torchtext==$TORCHTEXT_VERSION \
139-
-f https://download.pytorch.org/whl/torch_stable.html && \
135+
--index-url https://download.pytorch.org/whl/cpu && \
140136
/tmp/clean-layer.sh
141137
{{ end }}
142138

@@ -199,32 +195,22 @@ RUN apt-get update && \
199195

200196
RUN pip install -f http://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html h2o && /tmp/clean-layer.sh
201197

202-
# b/318672158 Use simply tensorflow-probability once > 0.23.0 is released.
203198
RUN pip install \
204199
"tensorflow==${TENSORFLOW_VERSION}" \
205200
"tensorflow-io==${TENSORFLOW_IO_VERSION}" \
206-
git+https://github.com/tensorflow/probability.git@fbc5ebe9b1d343113fb917010096cfd88b32eecf \
207-
tensorflow_text \
201+
tensorflow-probability \
202+
tensorflow_decision_forests \
203+
tensorflow-text \
208204
"tensorflow_hub>=0.16.0" \
209205
# b/331799280 remove once other packages over to dm-tre
210206
optree \
211207
tf-keras && \
212208
/tmp/clean-layer.sh
213209

214-
# b/318672158 Use simply tensorflow_decision_forests on next release, expected with tf 2.16
215-
RUN pip install tensorflow_decision_forests==1.8.1 --no-deps && \
216-
/tmp/clean-layer.sh
217-
218-
RUN chmod +x /tmp/keras_patch.sh && \
219-
/tmp/keras_patch.sh
220-
221210
ADD patches/keras_internal.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal.py
222211
ADD patches/keras_internal_test.py /opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/keras_internal_test.py
223212

224-
# Remove "--no-deps" flag and "namex" package once Keras 3.* is included in our base image.
225-
# We ignore dependencies since tf2.15 and Keras 3.* should work despite pip saying it won't.
226-
# Currently, keras tries to install a nightly version of tf 2.16: https://github.com/keras-team/keras/blob/fe2f54aa5bc42fb23a96449cf90434ab9bb6a2cd/requirements.txt#L2
227-
RUN pip install --no-deps "keras>3" keras-cv keras-nlp namex && \
213+
RUN pip install "keras>3" keras-cv keras-nlp && \
228214
/tmp/clean-layer.sh
229215

230216
# b/328788268 libpysal 4.10 seems to fail with "module 'shapely' has no attribute 'Geometry'. Did you mean: 'geometry'"

Jenkinsfile

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ pipeline {
3636
--package torch \
3737
--version $TORCH_VERSION \
3838
--build-arg TORCHAUDIO_VERSION=$TORCHAUDIO_VERSION \
39-
--build-arg TORCHTEXT_VERSION=$TORCHTEXT_VERSION \
4039
--build-arg TORCHVISION_VERSION=$TORCHVISION_VERSION \
4140
--build-arg CUDA_MAJOR_VERSION=$CUDA_MAJOR_VERSION \
4241
--build-arg CUDA_MINOR_VERSION=$CUDA_MINOR_VERSION \

config.txt

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
BASE_IMAGE_REPO=gcr.io/deeplearning-platform-release
2-
BASE_IMAGE_TAG=m114
3-
CPU_BASE_IMAGE_NAME=tf2-cpu.2-15.py310
4-
GPU_BASE_IMAGE_NAME=tf2-gpu.2-15.py310
2+
BASE_IMAGE_TAG=m122
3+
CPU_BASE_IMAGE_NAME=tf2-cpu.2-16.py310
4+
GPU_BASE_IMAGE_NAME=tf2-gpu.2-16.py310
55
LIGHTGBM_VERSION=4.2.0
6-
TORCH_VERSION=2.1.2
7-
TORCHAUDIO_VERSION=2.1.2
8-
TORCHTEXT_VERSION=0.16.2
9-
TORCHVISION_VERSION=0.16.2
6+
TORCH_VERSION=2.4.0
7+
TORCHAUDIO_VERSION=2.4.0
8+
TORCHVISION_VERSION=0.19.0
109
JAX_VERSION=0.4.26
1110
CUDA_MAJOR_VERSION=12
12-
CUDA_MINOR_VERSION=1
11+
CUDA_MINOR_VERSION=3

packages/jaxlib.Dockerfile

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ ENV LIBRARY_PATH="$LIBRARY_PATH:/opt/conda/lib"
1515
ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/opt/conda/lib"
1616

1717
# Instructions: https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source
18-
RUN apt-get update && \
19-
apt-get install -y g++ python python3-dev
18+
RUN sudo ln -s /usr/bin/python3 /usr/bin/python
19+
20+
RUN apt-get update && \
21+
apt-get install -y g++ python3 python3-dev
2022

2123
RUN pip install numpy wheel build
2224

packages/torch.Dockerfile

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ FROM ${BASE_IMAGE} AS builder
44

55
ARG PACKAGE_VERSION
66
ARG TORCHAUDIO_VERSION
7-
ARG TORCHTEXT_VERSION
87
ARG TORCHVISION_VERSION
98
ARG CUDA_MAJOR_VERSION
109
ARG CUDA_MINOR_VERSION
@@ -20,7 +19,7 @@ RUN conda install -c conda-forge mamba
2019

2120
# Build instructions: https://github.com/pytorch/pytorch#from-source
2221
RUN mamba install astunparse numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
23-
RUN mamba install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}
22+
RUN mamba install -c pytorch magma-cuda121
2423

2524
# By default, it uses the version from version.txt which includes the `a0` (alpha zero) suffix and part of the git hash.
2625
# This causes dependency conflicts like these: https://paste.googleplex.com/4786486378496000
@@ -63,18 +62,6 @@ RUN sudo apt-get update && \
6362
RUN sed -i 's/set(envs/set(envs\n "LIBS=-ltinfo"/' /usr/local/src/audio/third_party/sox/CMakeLists.txt
6463
RUN cd /usr/local/src/audio && python setup.py bdist_wheel
6564

66-
# Build torchtext
67-
# Instructions: https://github.com/pytorch/text#building-from-source
68-
# See comment above for PYTORCH_BUILD_VERSION.
69-
ENV BUILD_VERSION=$TORCHTEXT_VERSION
70-
RUN cd /usr/local/src && \
71-
git clone https://github.com/pytorch/text && \
72-
cd text && \
73-
git checkout tags/v$TORCHTEXT_VERSION && \
74-
git submodule sync && \
75-
git submodule update --init --recursive --jobs 1 && \
76-
python setup.py bdist_wheel
77-
7865
# Build torchvision.
7966
# Instructions: https://github.com/pytorch/vision/tree/main#installation
8067
# See comment above for PYTORCH_BUILD_VERSION.
@@ -93,7 +80,6 @@ FROM alpine:latest
9380
RUN mkdir -p /tmp/whl/
9481
COPY --from=builder /usr/local/src/pytorch/dist/*.whl /tmp/whl
9582
COPY --from=builder /usr/local/src/audio/dist/*.whl /tmp/whl
96-
COPY --from=builder /usr/local/src/text/dist/*.whl /tmp/whl
9783
COPY --from=builder /usr/local/src/vision/dist/*.whl /tmp/whl
9884

9985
# Print out the built .whl file.

patches/keras_patch.sh

Lines changed: 0 additions & 41 deletions
This file was deleted.

tests/test_geopandas.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
import unittest
22

33
import geopandas
4+
from shapely.geometry import Polygon
45

56
class TestGeopandas(unittest.TestCase):
6-
def test_read(self):
7-
df = geopandas.read_file(geopandas.datasets.get_path('nybb'))
8-
self.assertTrue(df.size > 1)
9-
10-
def test_spatial_join(self):
11-
cities = geopandas.read_file(geopandas.datasets.get_path('naturalearth_cities'))
12-
world = geopandas.read_file(geopandas.datasets.get_path('naturalearth_lowres'))
13-
countries = world[['geometry', 'name']]
14-
countries = countries.rename(columns={'name':'country'})
15-
cities_with_country = geopandas.sjoin(cities, countries, how="inner", op='intersects')
16-
self.assertTrue(cities_with_country.size > 1)
7+
def test_GeoSeries(self):
8+
p1 = Polygon([(0, 0), (1, 0), (1, 1)])
9+
p2 = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
10+
p3 = Polygon([(2, 0), (3, 0), (3, 1), (2, 1)])
11+
g = geopandas.GeoSeries([p1, p2, p3])

tests/test_torchtext.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

tpu/Dockerfile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ ARG TENSORFLOW_VERSION
1313
ARG TF_LIBTPU_VERSION
1414
ARG JAX_VERSION
1515
ARG TORCHVISION_VERSION
16-
ARG TORCHTEXT_VERSION
1716
ARG TORCHAUDIO_VERSION
1817

1918
ENV ISTPUVM=1
@@ -60,7 +59,7 @@ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
6059
# Additional useful packages should be added here
6160

6261
RUN pip install tensorflow_hub https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TF_LINUX_WHEEL_VERSION}.whl tensorflow-probability tensorflow-io \
63-
torch~=${TORCH_VERSION} https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}+libtpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl torchvision==${TORCHVISION_VERSION} torchtext==${TORCHTEXT_VERSION} torchaudio==${TORCHAUDIO_VERSION} \
62+
torch~=${TORCH_VERSION} https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}+libtpu-${PYTHON_WHEEL_VERSION}-${PYTHON_WHEEL_VERSION}-${TORCH_LINUX_WHEEL_VERSION}.whl torchvision==${TORCHVISION_VERSION} torchaudio==${TORCHAUDIO_VERSION} \
6463
jax[tpu]==${JAX_VERSION} -f https://storage.googleapis.com/jax-releases/libtpu_releases.html trax flax optax git+https://github.com/deepmind/dm-haiku jraph distrax \
6564
papermill jupyterlab python-lsp-server[all] "jupyter-lsp==1.5.1" \
6665
pandas matplotlib opencv-python-headless librosa accelerate diffusers scikit-learn transformers \

tpu/config.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ JAX_VERSION=0.4.23
1212
TORCH_VERSION=2.4.0
1313
# https://github.com/pytorch/audio supports nightly
1414
TORCHAUDIO_VERSION=2.4.0
15-
# https://github.com/pytorch/text supports main
16-
TORCHTEXT_VERSION=0.18.0
1715
# https://github.com/pytorch/vision supports nightly
1816
TORCHVISION_VERSION=0.19.0
1917
TORCH_LINUX_WHEEL_VERSION=manylinux_2_28_x86_64

0 commit comments

Comments
 (0)