Skip to content

Commit 3f16be2

Browse files
jataylopruthvistony
authored andcommitted
release/2.2 triton commit pin for rocm6.1 conditionalisation (#1369)
* Triton build conditionalized on ROCM_VERSION (cherry picked from commit 1a7e1fa) * Update pinned commit for rocm6.1 conditionalisation --------- Co-authored-by: Pruthvi Madugundu <[email protected]>
1 parent 701d205 commit 3f16be2

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
d08e16b738ab550c3af51305df624d5c823dc445
1+
54d0bb9e4b2e2dab7dc899008c0f14915f665a2f

.github/scripts/build_triton_wheel.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#!/usr/bin/env python3
22
import os
3+
import re
34
import shutil
45
import sys
56
from pathlib import Path
@@ -64,6 +65,31 @@ def patch_init_py(
6465
with open(path, "w") as f:
6566
f.write(orig)
6667

68+
def get_rocm_version() -> str:
69+
rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
70+
rocm_version = "0.0.0"
71+
rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
72+
if not os.path.isfile(rocm_version_h):
73+
rocm_version_h = f"{rocm_path}/include/rocm_version.h"
74+
75+
# The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
76+
if os.path.isfile(rocm_version_h):
77+
RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
78+
RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
79+
RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
80+
major, minor, patch = 0, 0, 0
81+
for line in open(rocm_version_h):
82+
match = RE_MAJOR.search(line)
83+
if match:
84+
major = int(match.group(1))
85+
match = RE_MINOR.search(line)
86+
if match:
87+
minor = int(match.group(1))
88+
match = RE_PATCH.search(line)
89+
if match:
90+
patch = int(match.group(1))
91+
rocm_version = str(major)+"."+str(minor)+"."+str(patch)
92+
return rocm_version
6793

6894
def build_triton(
6995
*,
@@ -170,7 +196,8 @@ def build_triton(
170196
version=f"{version}",
171197
expected_version=None,
172198
)
173-
check_call("scripts/amd/setup_rocm_libs.sh", cwd=triton_basedir, shell=True)
199+
cur_rocm_ver = get_rocm_version()
200+
check_call(["scripts/amd/setup_rocm_libs.sh", cur_rocm_ver], cwd=triton_basedir)
174201
print("ROCm libraries setup for triton installation...")
175202

176203
check_call(

0 commit comments

Comments
 (0)