|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 | import os
|
| 3 | +import re |
3 | 4 | import shutil
|
4 | 5 | import sys
|
5 | 6 | from pathlib import Path
|
@@ -64,6 +65,31 @@ def patch_init_py(
|
64 | 65 | with open(path, "w") as f:
|
65 | 66 | f.write(orig)
|
66 | 67 |
|
| 68 | +def get_rocm_version() -> str: |
| 69 | + rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm" |
| 70 | + rocm_version = "0.0.0" |
| 71 | + rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h" |
| 72 | + if not os.path.isfile(rocm_version_h): |
| 73 | + rocm_version_h = f"{rocm_path}/include/rocm_version.h" |
| 74 | + |
| 75 | + # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install. |
| 76 | + if os.path.isfile(rocm_version_h): |
| 77 | + RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)") |
| 78 | + RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)") |
| 79 | + RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)") |
| 80 | + major, minor, patch = 0, 0, 0 |
| 81 | + for line in open(rocm_version_h): |
| 82 | + match = RE_MAJOR.search(line) |
| 83 | + if match: |
| 84 | + major = int(match.group(1)) |
| 85 | + match = RE_MINOR.search(line) |
| 86 | + if match: |
| 87 | + minor = int(match.group(1)) |
| 88 | + match = RE_PATCH.search(line) |
| 89 | + if match: |
| 90 | + patch = int(match.group(1)) |
| 91 | + rocm_version = str(major)+"."+str(minor)+"."+str(patch) |
| 92 | + return rocm_version |
67 | 93 |
|
68 | 94 | def build_triton(
|
69 | 95 | *,
|
@@ -170,7 +196,8 @@ def build_triton(
|
170 | 196 | version=f"{version}",
|
171 | 197 | expected_version=None,
|
172 | 198 | )
|
173 |
| - check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True) |
| 199 | + cur_rocm_ver = get_rocm_version() |
| 200 | + check_call(["scripts/amd/setup_rocm_libs.sh", cur_rocm_ver], cwd=triton_basedir) |
174 | 201 | print("ROCm libraries setup for triton installation...")
|
175 | 202 |
|
176 | 203 | check_call(
|
|
0 commit comments