Skip to content

Commit 0528d46

Browse files
committed
TPU Image with python3.8
Tensorflow for TPU VM is not supported on python3.7 and therefore we need a brand new image instead of one based on prior images. http://b/213335159
1 parent efdad4a commit 0528d46

File tree

4 files changed

+45
-113
lines changed

4 files changed

+45
-113
lines changed

Jenkinsfile

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -56,22 +56,6 @@ pipeline {
5656
'''
5757
}
5858
}
59-
stage('tensorflow TPU') {
60-
options {
61-
timeout(time: 240, unit: 'MINUTES')
62-
}
63-
steps {
64-
sh '''#!/bin/bash
65-
set -exo pipefail
66-
source tpu/config.txt
67-
cd packages/
68-
./build_package --base-image gcr.io/kaggle-images/python:${BASE_IMAGE_TAG} \
69-
--package tpu-tensorflow \
70-
--version $TENSORFLOW_VERSION \
71-
--push
72-
'''
73-
}
74-
}
7559
}
7660
}
7761
stage('Build/Test/Diff') {
@@ -171,23 +155,12 @@ pipeline {
171155
stages {
172156
stage('Build Tensorflow TPU Image') {
173157
options {
174-
timeout(time: 20, unit: 'MINUTES')
158+
timeout(time: 60, unit: 'MINUTES')
175159
}
176160
steps {
177161
sh '''#!/bin/bash
178162
set -exo pipefail
179163
180-
# Login to docker to get access to gcr.io/cloud-tpu-v2-images/libtpu
181-
182-
# To grant access to a SA, start a TPU VM with that SA once.
183-
# Disable echo to avoid printing sensitive tokens:
184-
set +x
185-
METADATA=http://metadata.google.internal/computeMetadata/v1
186-
SVC_ACCT=$METADATA/instance/service-accounts/default
187-
ACCESS_TOKEN=$(/usr/bin/curl -s -H 'Metadata-Flavor: Google' $SVC_ACCT/token | cut -d'"' -f 4)
188-
docker login --username oauth2accesstoken --password $ACCESS_TOKEN https://gcr.io
189-
set -x
190-
191164
./tpu/build | ts
192165
./push --tpu ${PRETEST_TAG}
193166
'''

packages/tpu-tensorflow.Dockerfile

Lines changed: 0 additions & 56 deletions
This file was deleted.

tpu/Dockerfile

Lines changed: 43 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,60 @@
1-
ARG BASE_IMAGE_TAG
2-
ARG TENSORFLOW_VERSION
3-
4-
FROM gcr.io/kaggle-images/python-tpu-tensorflow-whl:python-${BASE_IMAGE_TAG}-${TENSORFLOW_VERSION} AS tensorflow_whl
5-
FROM gcr.io/kaggle-images/python:${BASE_IMAGE_TAG}
1+
FROM python:3.8
62

73
# We need to define the ARG here to get the ARG below the FROM statement to access it within this build context
84
# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
95
ARG TORCH_VERSION
6+
ARG TENSORFLOW_VERSION
107

118
ENV ISTPUVM=1
129

13-
COPY --from=tensorflow_whl /tmp/tensorflow_pkg/tensorflow*.whl /tmp/tensorflow_pkg/
14-
RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \
15-
rm -rf /tmp/tensorflow_pkg && \
16-
/tmp/clean-layer.sh
10+
ADD patches/nbconvert-extensions.tpl /opt/kaggle/nbconvert-extensions.tpl
11+
ADD patches/template_conf.json /opt/kaggle/conf.json
12+
13+
# Tensorflow install:
14+
RUN pip install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-${TENSORFLOW_VERSION}/tensorflow-${TENSORFLOW_VERSION}-cp38-cp38-linux_x86_64.whl && \
15+
curl --output /lib/libtpu.so https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.3.0/libtpu.so
1716

1817
# LIBTPU installed here:
19-
ENV DEFAULT_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/libtpu.so
20-
ENV PYTORCH_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/torch-libtpu.so
21-
ENV JAX_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/jax-libtpu.so
18+
ENV PIP_LIBTPU=/usr/local/lib/python3.8/site-packages/libtpu/libtpu.so
19+
ENV DEFAULT_LIBTPU=/lib/libtpu.so
20+
ENV PYTORCH_LIBTPU=/lib/torch-libtpu.so
21+
ENV JAX_LIBTPU=/lib/jax-libtpu.so
22+
23+
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
24+
RUN pip install "jax[tpu]==0.3.10" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
25+
26+
RUN cp $PIP_LIBTPU $JAX_LIBTPU
2227

2328
# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
24-
RUN pip uninstall -y torch && \
25-
pip install torch==${TORCH_VERSION} && \
26-
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
27-
pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp37-cp37m-linux_x86_64.whl && \
28-
cp $DEFAULT_LIBTPU $PYTORCH_LIBTPU && \
29-
/tmp/clean-layer.sh
29+
RUN pip install torch==${TORCH_VERSION}
30+
31+
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
32+
RUN pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp38-cp38-linux_x86_64.whl
3033

31-
# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
32-
RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
33-
cp $DEFAULT_LIBTPU $JAX_LIBTPU && \
34-
/tmp/clean-layer.sh
34+
RUN cp $PIP_LIBTPU $PYTORCH_LIBTPU
3535

3636
# Monkey-patch TF, JAX & PYTORCH to load the correct libtpu.so when they are imported:
37-
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \
38-
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py && \
39-
sed -i "1s/^/from jax._src.cloud_tpu_init import cloud_tpu_init\ncloud_tpu_init()\n/" /opt/conda/lib/python3.7/site-packages/tensorflow/__init__.py
37+
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /usr/local/lib/python3.8/site-packages/torch_xla/__init__.py && \
38+
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /usr/local/lib/python3.8/site-packages/jax/_src/cloud_tpu_init.py
39+
40+
# Packages needed by the Notebook editor:
41+
RUN pip install papermill jupyterlab python-lsp-server[all] jupyterlab-lsp
4042

4143
# Set these env vars so that they don't produce errs calling the metadata server to load them:
4244
ENV TPU_ACCELERATOR_TYPE=v3-8
43-
ENV TPU_PROCESS_ADDRESSES=local
45+
ENV TPU_PROCESS_ADDRESSES=local
46+
47+
# Metadata
48+
ARG GIT_COMMIT=unknown
49+
ARG BUILD_DATE=unknown
50+
51+
LABEL git-commit=$GIT_COMMIT
52+
LABEL build-date=$BUILD_DATE
53+
ENV GIT_COMMIT=${GIT_COMMIT}
54+
ENV BUILD_DATE=${BUILD_DATE}
55+
56+
LABEL tensorflow-version=$TENSORFLOW_VERSION
57+
LABEL kaggle-lang=python
58+
59+
# Correlate current release with the git hash inside the kernel editor by running `!cat /etc/git_commit`.
60+
RUN echo "$GIT_COMMIT" > /etc/git_commit && echo "$BUILD_DATE" > /etc/build_date

tpu/config.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,2 @@
1-
# TODO(b/213335159): Use ci-pretest for BASE_IMAGE_TAG once stable.
2-
BASE_IMAGE_TAG=v115
3-
TENSORFLOW_VERSION=2.8.0
1+
TENSORFLOW_VERSION=2.9.1
42
TORCH_VERSION=1.11.0

0 commit comments

Comments
 (0)