
Commit f8abbf7 (parent 42bec18)

[WIP] enable AMD GPU

Signed-off-by: Vicky Tsang <[email protected]>

2 files changed: +25, −7 lines

monai/deploy/packager/util.py

Lines changed: 20 additions & 6 deletions
@@ -42,7 +42,11 @@ def verify_base_image(base_image: str) -> str:
         str: returns string identifier of the dockerfile template to build MAP
              if valid base image provided, returns empty string otherwise
     """
-    valid_prefixes = {"nvcr.io/nvidia/cuda": "ubuntu", "nvcr.io/nvidia/pytorch": "pytorch"}
+    import torch
+    if "AMD" not in torch.cuda.get_device_name(0):
+        valid_prefixes = {"nvcr.io/nvidia/cuda": "ubuntu", "nvcr.io/nvidia/pytorch": "pytorch"}
+    else:
+        valid_prefixes = {"rocm": "ubuntu", "rocm/pytorch": "pytorch"}
 
     for prefix, template in valid_prefixes.items():
         if prefix in base_image:
@@ -89,12 +93,22 @@ def initialize_args(args: Namespace) -> Dict:
     if args.base:
         dockerfile_type = verify_base_image(args.base)
         if not dockerfile_type:
-            logger.error(
-                "Provided base image '{}' is not supported \n \
-                Please provide a Cuda or Pytorch image from https://ngc.nvidia.com/ (nvcr.io/nvidia)".format(
-                    args.base
+            import torch
+            if "AMD" not in torch.cuda.get_device_name(0):
+                logger.error(
+                    "Provided base image '{}' is not supported \n \
+                    Please provide a Cuda or Pytorch image from https://ngc.nvidia.com/ (nvcr.io/nvidia)".format(
+                        args.base
+                    )
                 )
-            )
+            else:
+                logger.error(
+                    "Provided base image '{}' is not supported \n \
+                    Please provide a ROCm or Pytorch image from https://hub.docker.com/r/rocm/pytorch".format(
+                        args.base
+                    )
+                )
+
             sys.exit(1)
 
     processed_args["dockerfile_type"] = dockerfile_type if args.base else DefaultValues.DOCKERFILE_TYPE
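
Both hunks detect the GPU vendor with torch.cuda.get_device_name(0): ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API, so the reported device name contains "AMD". A minimal sketch of pulling that check into one helper is below; the helper name and the torch.version.hip shortcut are assumptions for illustration, not part of this commit, and the call requires a visible GPU just as the inline check does.

    # Sketch only -- not from the commit. Same torch-based detection, factored out.
    import torch

    def gpu_is_amd() -> bool:
        """Best-effort check for an AMD (ROCm) GPU behind the torch.cuda API."""
        if not torch.cuda.is_available():
            return False
        # torch.version.hip is None on CUDA builds and a version string on ROCm builds.
        if getattr(torch.version, "hip", None):
            return True
        # Fallback mirrors the inline string check added in this commit.
        return "AMD" in torch.cuda.get_device_name(0)

With a helper like this, verify_base_image and initialize_args could branch once on gpu_is_amd() instead of repeating the inline import and string match, and machines without any GPU would not raise from get_device_name(0).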

monai/deploy/runner/runner.py

Lines changed: 5 additions & 1 deletion
@@ -87,7 +87,9 @@ def run_app(map_name: str, input_path: Path, output_path: Path, app_info: dict,
     # Use nvidia-docker if GPU resources are requested
     requested_gpus = get_requested_gpus(pkg_info)
     if requested_gpus > 0:
-        cmd = "nvidia-docker run --rm -a STDERR"
+        import torch
+        if "AMD" not in torch.cuda.get_device_name(0):
+            cmd = "nvidia-docker run --rm -a STDERR"
 
     if not quiet:
         cmd += " -a STDOUT"
@@ -160,6 +162,8 @@ def pkg_specific_dependency_verification(pkg_info: dict) -> bool:
     """
     requested_gpus = get_requested_gpus(pkg_info)
     if requested_gpus > 0:
+        import torch
+        if "AMD" not in torch.cuda.get_device_name(0):
         # check for nvidia-docker
         prog = "nvidia-docker"
         logger.info('--> Verifying if "%s" is installed...\n', prog)
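
Note that run_app only wraps the existing nvidia-docker command in the new guard, and in pkg_specific_dependency_verification the added condition does not yet re-indent the nvidia-docker check beneath it, so the AMD code path is still open (hence the [WIP] tag). For ROCm containers the commonly documented pattern is plain docker run with the /dev/kfd and /dev/dri devices passed through; the sketch below shows how the runner command could branch, with the AMD flags taken from ROCm's container documentation rather than from this commit, and the non-GPU default assumed to match the runner's existing plain docker command.

    # Sketch only -- hypothetical helper, not from the commit.
    def build_run_command(requested_gpus: int, amd_gpu: bool) -> str:
        """Choose the container launch command for the detected GPU backend."""
        if requested_gpus <= 0:
            # Assumed to match the runner's existing non-GPU command.
            return "docker run --rm -a STDERR"
        if not amd_gpu:
            return "nvidia-docker run --rm -a STDERR"
        # ROCm images typically run under plain docker with the kernel GPU devices exposed.
        return "docker run --rm -a STDERR --device=/dev/kfd --device=/dev/dri --group-add video"

A later revision would also need an AMD branch in the dependency verification, since checking for nvidia-docker is only meaningful on the NVIDIA path.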
