|
1 | 1 | #!/usr/bin/env python3
|
2 | 2 |
|
3 | 3 | import os
|
| 4 | +import re |
4 | 5 | import shutil
|
5 | 6 | import sys
|
6 | 7 | from pathlib import Path
|
@@ -47,6 +48,31 @@ def patch_init_py(
|
47 | 48 | with open(path, "w") as f:
|
48 | 49 | f.write(orig)
|
49 | 50 |
|
| 51 | +def get_rocm_version() -> str: |
| 52 | + rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm" |
| 53 | + rocm_version = "0.0.0" |
| 54 | + rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h" |
| 55 | + if not os.path.isfile(rocm_version_h): |
| 56 | + rocm_version_h = f"{rocm_path}/include/rocm_version.h" |
| 57 | + |
| 58 | + # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install. |
| 59 | + if os.path.isfile(rocm_version_h): |
| 60 | + RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)") |
| 61 | + RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)") |
| 62 | + RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)") |
| 63 | + major, minor, patch = 0, 0, 0 |
| 64 | + for line in open(rocm_version_h): |
| 65 | + match = RE_MAJOR.search(line) |
| 66 | + if match: |
| 67 | + major = int(match.group(1)) |
| 68 | + match = RE_MINOR.search(line) |
| 69 | + if match: |
| 70 | + minor = int(match.group(1)) |
| 71 | + match = RE_PATCH.search(line) |
| 72 | + if match: |
| 73 | + patch = int(match.group(1)) |
| 74 | + rocm_version = str(major)+"."+str(minor)+"."+str(patch) |
| 75 | + return rocm_version |
50 | 76 |
|
51 | 77 | # TODO: remove patch_setup_py() once we have a proper fix for https://github.com/triton-lang/triton/issues/4527
|
52 | 78 | def patch_setup_py(path: Path) -> None:
|
@@ -85,7 +111,8 @@ def build_triton(
|
85 | 111 | if not release:
|
86 | 112 | # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
|
87 | 113 | # while release build should only include the version, i.e. 2.1.0
|
88 |
| - version_suffix = f"+{commit_hash[:10]}" |
| 114 | + rocm_version = get_rocm_version() |
| 115 | + version_suffix = f"+rocm{rocm_version}_{commit_hash[:10]}" |
89 | 116 | version += version_suffix
|
90 | 117 |
|
91 | 118 | with TemporaryDirectory() as tmpdir:
|
@@ -175,6 +202,8 @@ def build_triton(
|
175 | 202 | cwd=triton_basedir,
|
176 | 203 | shell=True,
|
177 | 204 | )
|
| 205 | + cur_rocm_ver = get_rocm_version() |
| 206 | + check_call(["scripts/amd/setup_rocm_libs.sh", cur_rocm_ver], cwd=triton_basedir) |
178 | 207 | print("ROCm libraries setup for triton installation...")
|
179 | 208 |
|
180 | 209 | check_call(
|
|
0 commit comments