5 files changed: +21 / -3 lines

@@ -12,8 +12,10 @@ ARG TORCHVISION_VERSION
 FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
 FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
 FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
-ENV CUDA_MAJOR_VERSION=11
-ENV CUDA_MINOR_VERSION=0
+ARG CUDA_MAJOR_VERSION
+ARG CUDA_MINOR_VERSION
+ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
+ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
 # NVIDIA binaries from the host are mounted to /opt/bin.
 ENV PATH=/opt/bin:${PATH}
 # Add CUDA stubs to LD_LIBRARY_PATH to support building the GPU image on a CPU machine.
@@ -99,7 +101,8 @@ RUN conda install implicit && \
 # Install PyTorch
 {{ if eq .Accelerator "gpu" }}
 COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
-RUN pip install /tmp/torch/*.whl && \
+RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION} && \
+    pip install /tmp/torch/*.whl && \
     rm -rf /tmp/torch && \
     /tmp/clean-layer.sh
 {{ else }}
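
With the defaults added to the config file below (CUDA_MAJOR_VERSION=11, CUDA_MINOR_VERSION=0), the interpolated package installed here is magma-cuda110 from the pytorch conda channel. A purely illustrative sketch of that naming convention in Python (the helper is hypothetical, not part of the build):

# Illustrative only: how the conda package name used above is composed
# from the two build arguments.
def magma_package(cuda_major: str, cuda_minor: str) -> str:
    return f"magma-cuda{cuda_major}{cuda_minor}"

print(magma_package("11", "0"))  # -> magma-cuda110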
@@ -37,6 +37,8 @@ pipeline {
             --build-arg TORCHAUDIO_VERSION=$TORCHAUDIO_VERSION \
             --build-arg TORCHTEXT_VERSION=$TORCHTEXT_VERSION \
             --build-arg TORCHVISION_VERSION=$TORCHVISION_VERSION \
+            --build-arg CUDA_MAJOR_VERSION=$CUDA_MAJOR_VERSION \
+            --build-arg CUDA_MINOR_VERSION=$CUDA_MINOR_VERSION \
             --push
         '''
     }
@@ -7,3 +7,5 @@ TORCH_VERSION=1.9.1
 TORCHAUDIO_VERSION=0.9.1
 TORCHTEXT_VERSION=0.10.1
 TORCHVISION_VERSION=0.10.1
+CUDA_MAJOR_VERSION=11
+CUDA_MINOR_VERSION=0
@@ -6,12 +6,15 @@ ARG PACKAGE_VERSION
 ARG TORCHAUDIO_VERSION
 ARG TORCHTEXT_VERSION
 ARG TORCHVISION_VERSION
+ARG CUDA_MAJOR_VERSION
+ARG CUDA_MINOR_VERSION
 
 # TORCHVISION_VERSION is mandatory
 RUN test -n "$TORCHVISION_VERSION"
 
 # Build instructions: https://github.com/pytorch/pytorch#from-source
 RUN conda install astunparse numpy ninja pyyaml mkl mkl-include setuptools==59.5.0 cmake cffi typing_extensions future six requests dataclasses
+RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}
 
 # By default, it uses the version from version.txt which includes the `a0` (alpha zero) suffix and part of the git hash.
 # This causes dependency conflicts like these: https://paste.googleplex.com/4786486378496000
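
Installing the MAGMA package before the source build is what gives the resulting wheel GPU-backed torch.linalg kernels. A minimal post-build sanity check, as a sketch (assuming torch.version.cuda and torch.cuda.has_magma are exposed by the installed build, as they are in the 1.9 series):

import torch

# Sketch of a post-build check: the wheel should report the CUDA toolkit it
# was compiled against and should see MAGMA support at import time.
print(torch.__version__)
print(torch.version.cuda)    # expected to match CUDA_MAJOR_VERSION.CUDA_MINOR_VERSION, e.g. '11.0'
print(torch.cuda.has_magma)  # expected True once magma-cuda is available at build time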
@@ -15,6 +15,14 @@ def test_nn(self):
         data_torch = autograd.Variable(torch.randn(2, 5))
         linear_torch(data_torch)
 
+    @gpu_test
+    def test_linalg(self):
+        A = torch.randn(3, 3).t().to('cuda')
+        B = torch.randn(3).t().to('cuda')
+
+        result = torch.linalg.solve(A, B)
+        self.assertEqual(3, result.shape[0])
+
     @gpu_test
     def test_gpu_computation(self):
         cuda = torch.device('cuda')
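
Outside the test harness, the same behaviour can be spot-checked directly; a minimal sketch, assuming a CUDA device and a MAGMA-enabled build like the one produced above:

import torch

# Solve A x = b on the GPU; on PyTorch 1.9 the CUDA path of torch.linalg.solve
# depends on MAGMA, which is why the images now install magma-cuda at build time.
A = torch.randn(3, 3, device='cuda')
b = torch.randn(3, device='cuda')
x = torch.linalg.solve(A, b)

assert x.shape == (3,)
# The residual should be near zero for a well-conditioned A.
assert torch.allclose(A @ x, b, atol=1e-4)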