1
1
#!/usr/bin/env python3
2
2
3
3
import os
4
+ import re
4
5
import shutil
5
6
import sys
6
7
from pathlib import Path
@@ -47,6 +48,31 @@ def patch_init_py(
47
48
with open (path , "w" ) as f :
48
49
f .write (orig )
49
50
51
+ def get_rocm_version () -> str :
52
+ rocm_path = os .environ .get ('ROCM_HOME' ) or os .environ .get ('ROCM_PATH' ) or "/opt/rocm"
53
+ rocm_version = "0.0.0"
54
+ rocm_version_h = f"{ rocm_path } /include/rocm-core/rocm_version.h"
55
+ if not os .path .isfile (rocm_version_h ):
56
+ rocm_version_h = f"{ rocm_path } /include/rocm_version.h"
57
+
58
+ # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
59
+ if os .path .isfile (rocm_version_h ):
60
+ RE_MAJOR = re .compile (r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)" )
61
+ RE_MINOR = re .compile (r"#define\s+ROCM_VERSION_MINOR\s+(\d+)" )
62
+ RE_PATCH = re .compile (r"#define\s+ROCM_VERSION_PATCH\s+(\d+)" )
63
+ major , minor , patch = 0 , 0 , 0
64
+ for line in open (rocm_version_h ):
65
+ match = RE_MAJOR .search (line )
66
+ if match :
67
+ major = int (match .group (1 ))
68
+ match = RE_MINOR .search (line )
69
+ if match :
70
+ minor = int (match .group (1 ))
71
+ match = RE_PATCH .search (line )
72
+ if match :
73
+ patch = int (match .group (1 ))
74
+ rocm_version = str (major )+ "." + str (minor )+ "." + str (patch )
75
+ return rocm_version
50
76
51
77
def build_triton (
52
78
* ,
@@ -63,6 +89,14 @@ def build_triton(
63
89
max_jobs = os .cpu_count () or 1
64
90
env ["MAX_JOBS" ] = str (max_jobs )
65
91
92
+ version_suffix = ""
93
+ if not release :
94
+ # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
95
+ # while release build should only include the version, i.e. 2.1.0
96
+ rocm_version = get_rocm_version ()
97
+ version_suffix = f"+rocm{ rocm_version } _{ commit_hash [:10 ]} "
98
+ version += version_suffix
99
+
66
100
with TemporaryDirectory () as tmpdir :
67
101
triton_basedir = Path (tmpdir ) / "triton"
68
102
triton_pythondir = triton_basedir / "python"
@@ -149,6 +183,8 @@ def build_triton(
149
183
cwd = triton_basedir ,
150
184
shell = True ,
151
185
)
186
+ cur_rocm_ver = get_rocm_version ()
187
+ check_call (["scripts/amd/setup_rocm_libs.sh" , cur_rocm_ver ], cwd = triton_basedir )
152
188
print ("ROCm libraries setup for triton installation..." )
153
189
154
190
check_call (
0 commit comments