Skip to content

Commit 2ac680f

Browse files
authored
Fix MAGMA PyTorch issue with GPU (#1154)
- Add test to prevent regression. http://b/231736279
1 parent 2e43fcc commit 2ac680f

File tree

5 files changed

+21
-3
lines changed

5 files changed

+21
-3
lines changed

Dockerfile.tmpl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ ARG TORCHVISION_VERSION
1212
FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
1313
FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
1414
FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
15-
ENV CUDA_MAJOR_VERSION=11
16-
ENV CUDA_MINOR_VERSION=0
15+
ARG CUDA_MAJOR_VERSION
16+
ARG CUDA_MINOR_VERSION
17+
ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
18+
ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
1719
# NVIDIA binaries from the host are mounted to /opt/bin.
1820
ENV PATH=/opt/bin:${PATH}
1921
# Add CUDA stubs to LD_LIBRARY_PATH to support building the GPU image on a CPU machine.
@@ -99,7 +101,8 @@ RUN conda install implicit && \
99101
# Install PyTorch
100102
{{ if eq .Accelerator "gpu" }}
101103
COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
102-
RUN pip install /tmp/torch/*.whl && \
104+
RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION} && \
105+
pip install /tmp/torch/*.whl && \
103106
rm -rf /tmp/torch && \
104107
/tmp/clean-layer.sh
105108
{{ else }}

Jenkinsfile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ pipeline {
3737
--build-arg TORCHAUDIO_VERSION=$TORCHAUDIO_VERSION \
3838
--build-arg TORCHTEXT_VERSION=$TORCHTEXT_VERSION \
3939
--build-arg TORCHVISION_VERSION=$TORCHVISION_VERSION \
40+
--build-arg CUDA_MAJOR_VERSION=$CUDA_MAJOR_VERSION \
41+
--build-arg CUDA_MINOR_VERSION=$CUDA_MINOR_VERSION \
4042
--push
4143
'''
4244
}

config.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@ TORCH_VERSION=1.9.1
77
TORCHAUDIO_VERSION=0.9.1
88
TORCHTEXT_VERSION=0.10.1
99
TORCHVISION_VERSION=0.10.1
10+
CUDA_MAJOR_VERSION=11
11+
CUDA_MINOR_VERSION=0

packages/torch.Dockerfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@ ARG PACKAGE_VERSION
66
ARG TORCHAUDIO_VERSION
77
ARG TORCHTEXT_VERSION
88
ARG TORCHVISION_VERSION
9+
ARG CUDA_MAJOR_VERSION
10+
ARG CUDA_MINOR_VERSION
911

1012
# TORCHVISION_VERSION is mandatory
1113
RUN test -n "$TORCHVISION_VERSION"
1214

1315
# Build instructions: https://github.com/pytorch/pytorch#from-source
1416
RUN conda install astunparse numpy ninja pyyaml mkl mkl-include setuptools==59.5.0 cmake cffi typing_extensions future six requests dataclasses
17+
RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}
1518

1619
# By default, it uses the version from version.txt which includes the `a0` (alpha zero) suffix and part of the git hash.
1720
# This causes dependency conflicts like these: https://paste.googleplex.com/4786486378496000

tests/test_pytorch.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ def test_nn(self):
1515
data_torch = autograd.Variable(torch.randn(2, 5))
1616
linear_torch(data_torch)
1717

18+
@gpu_test
19+
def test_linalg(self):
20+
A = torch.randn(3, 3).t().to('cuda')
21+
B = torch.randn(3).t().to('cuda')
22+
23+
result = torch.linalg.solve(A, B)
24+
self.assertEqual(3, result.shape[0])
25+
1826
@gpu_test
1927
def test_gpu_computation(self):
2028
cuda = torch.device('cuda')

0 commit comments

Comments (0)