Skip to content

Commit d6afb67

Browse files
committed
CONSOLIDATED COMMITS: Triton build updates
========================================== Triton build conditionalized on ROCM_VERSION Include the ROCm version in triton version
1 parent 92e23cd commit d6afb67

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

.github/scripts/build_triton_wheel.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import os
4+
import re
45
import shutil
56
import sys
67
from pathlib import Path
@@ -47,6 +48,31 @@ def patch_init_py(
4748
with open(path, "w") as f:
4849
f.write(orig)
4950

51+
def get_rocm_version() -> str:
52+
rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
53+
rocm_version = "0.0.0"
54+
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
55+
if not os.path.isfile(rocm_version_h):
56+
rocm_version_h = f"{rocm_path}/include/rocm_version.h"
57+
58+
# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
59+
if os.path.isfile(rocm_version_h):
60+
RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
61+
RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
62+
RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
63+
major, minor, patch = 0, 0, 0
64+
for line in open(rocm_version_h):
65+
match = RE_MAJOR.search(line)
66+
if match:
67+
major = int(match.group(1))
68+
match = RE_MINOR.search(line)
69+
if match:
70+
minor = int(match.group(1))
71+
match = RE_PATCH.search(line)
72+
if match:
73+
patch = int(match.group(1))
74+
rocm_version = str(major)+"."+str(minor)+"."+str(patch)
75+
return rocm_version
5076

5177
def build_triton(
5278
*,
@@ -63,6 +89,14 @@ def build_triton(
6389
max_jobs = os.cpu_count() or 1
6490
env["MAX_JOBS"] = str(max_jobs)
6591

92+
version_suffix = ""
93+
if not release:
94+
# Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
95+
# while release build should only include the version, i.e. 2.1.0
96+
rocm_version = get_rocm_version()
97+
version_suffix = f"+rocm{rocm_version}_{commit_hash[:10]}"
98+
version += version_suffix
99+
66100
with TemporaryDirectory() as tmpdir:
67101
triton_basedir = Path(tmpdir) / "triton"
68102
triton_pythondir = triton_basedir / "python"
@@ -149,6 +183,8 @@ def build_triton(
149183
cwd=triton_basedir,
150184
shell=True,
151185
)
186+
cur_rocm_ver = get_rocm_version()
187+
check_call(["scripts/amd/setup_rocm_libs.sh", cur_rocm_ver], cwd=triton_basedir)
152188
print("ROCm libraries setup for triton installation...")
153189

154190
check_call(

0 commit comments

Comments
 (0)