Skip to content

Commit 1a7e1fa

Browse files
committed
Triton build conditionalized on ROCM_VERSION
1 parent 1bc9722 commit 1a7e1fa

File tree

1 file changed

+28
-1
lines changed

1 file changed

+28
-1
lines changed

.github/scripts/build_triton_wheel.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
import os
3+
import re
34
import shutil
45
import sys
56
from pathlib import Path
@@ -67,6 +68,31 @@ def patch_init_py(
6768
with open(path, "w") as f:
6869
f.write(orig)
6970

71+
def get_rocm_version() -> str:
72+
rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
73+
rocm_version = "0.0.0"
74+
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
75+
if not os.path.isfile(rocm_version_h):
76+
rocm_version_h = f"{rocm_path}/include/rocm_version.h"
77+
78+
# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
79+
if os.path.isfile(rocm_version_h):
80+
RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
81+
RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
82+
RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
83+
major, minor, patch = 0, 0, 0
84+
for line in open(rocm_version_h):
85+
match = RE_MAJOR.search(line)
86+
if match:
87+
major = int(match.group(1))
88+
match = RE_MINOR.search(line)
89+
if match:
90+
minor = int(match.group(1))
91+
match = RE_PATCH.search(line)
92+
if match:
93+
patch = int(match.group(1))
94+
rocm_version = str(major)+"."+str(minor)+"."+str(patch)
95+
return rocm_version
7096

7197
def build_triton(
7298
*,
@@ -166,7 +192,8 @@ def build_triton(
166192
version=f"{version}",
167193
expected_version=ROCM_TRITION_VERSION,
168194
)
169-
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
195+
cur_rocm_ver = get_rocm_version()
196+
check_call(["scripts/amd/setup_rocm_libs.sh", cur_rocm_ver], cwd=triton_basedir)
170197
print("ROCm libraries setup for triton installation...")
171198

172199
check_call(

0 commit comments

Comments
 (0)