Skip to content

Commit 2e1049d

Browse files
pruthvistonydnikolaev-amd
authored andcommitted
CONSOLIDATED COMMITS: Triton build updates
========================================== Triton build conditionalized on ROCM_VERSION Include the ROCm version in triton version (cherry picked from commit 7d33910)
1 parent 4e3b69a commit 2e1049d

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

.github/scripts/build_triton_wheel.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#!/usr/bin/env python3
22

33
import os
4+
import re
45
import shutil
56
import sys
67
from pathlib import Path
@@ -47,6 +48,31 @@ def patch_init_py(
4748
with open(path, "w") as f:
4849
f.write(orig)
4950

51+
def get_rocm_version() -> str:
52+
rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
53+
rocm_version = "0.0.0"
54+
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
55+
if not os.path.isfile(rocm_version_h):
56+
rocm_version_h = f"{rocm_path}/include/rocm_version.h"
57+
58+
# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
59+
if os.path.isfile(rocm_version_h):
60+
RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
61+
RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
62+
RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
63+
major, minor, patch = 0, 0, 0
64+
for line in open(rocm_version_h):
65+
match = RE_MAJOR.search(line)
66+
if match:
67+
major = int(match.group(1))
68+
match = RE_MINOR.search(line)
69+
if match:
70+
minor = int(match.group(1))
71+
match = RE_PATCH.search(line)
72+
if match:
73+
patch = int(match.group(1))
74+
rocm_version = str(major)+"."+str(minor)+"."+str(patch)
75+
return rocm_version
5076

5177
def build_triton(
5278
*,
@@ -62,6 +88,14 @@ def build_triton(
6288
max_jobs = os.cpu_count() or 1
6389
env["MAX_JOBS"] = str(max_jobs)
6490

91+
version_suffix = ""
92+
if not release:
93+
# Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
94+
# while release build should only include the version, i.e. 2.1.0
95+
rocm_version = get_rocm_version()
96+
version_suffix = f"+rocm{rocm_version}_{commit_hash[:10]}"
97+
version += version_suffix
98+
6599
with TemporaryDirectory() as tmpdir:
66100
triton_basedir = Path(tmpdir) / "triton"
67101
triton_pythondir = triton_basedir / "python"
@@ -99,6 +133,8 @@ def build_triton(
99133
cwd=triton_basedir,
100134
shell=True,
101135
)
136+
cur_rocm_ver = get_rocm_version()
137+
check_call(["scripts/amd/setup_rocm_libs.sh", cur_rocm_ver], cwd=triton_basedir)
102138
print("ROCm libraries setup for triton installation...")
103139

104140
check_call(

0 commit comments

Comments
 (0)