11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
13
14
+ #include " llvm/ADT/STLExtras.h"
15
+ #include " llvm/Support/CommandLine.h"
16
+ #include " llvm/Support/ConvertUTF.h"
14
17
#include " llvm/Support/DynamicLibrary.h"
15
18
#include " llvm/Support/Error.h"
19
+ #include " llvm/Support/FileSystem.h"
20
+ #include " llvm/Support/Path.h"
21
+ #include " llvm/Support/Process.h"
22
+ #include " llvm/Support/Program.h"
23
+ #include " llvm/Support/VersionTuple.h"
16
24
#include " llvm/Support/raw_ostream.h"
25
+ #include < algorithm>
26
+ #include < string>
27
+ #include < vector>
28
+
29
+ #ifdef _WIN32
30
+ #include < windows.h>
31
+ #endif
17
32
18
33
using namespace llvm ;
19
34
@@ -31,16 +46,124 @@ typedef hipError_t (*hipGetDeviceCount_t)(int *);
31
46
typedef hipError_t (*hipDeviceGet_t)(int *, int );
32
47
typedef hipError_t (*hipGetDeviceProperties_t)(hipDeviceProp_t *, int );
33
48
34
- int printGPUsByHIP () {
49
+ extern cl::opt<bool > Verbose;
50
+
35
51
#ifdef _WIN32
36
- constexpr const char *DynamicHIPPath = " amdhip64.dll" ;
52
+ static std::vector<std::string> getSearchPaths () {
53
+ std::vector<std::string> Paths;
54
+
55
+ // Get the directory of the current executable
56
+ if (auto MainExe = sys::fs::getMainExecutable (nullptr , nullptr );
57
+ !MainExe.empty ())
58
+ Paths.push_back (sys::path::parent_path (MainExe).str ());
59
+
60
+ // Get the system directory
61
+ wchar_t SystemDirectory[MAX_PATH];
62
+ if (GetSystemDirectoryW (SystemDirectory, MAX_PATH) > 0 ) {
63
+ std::string Utf8SystemDir;
64
+ if (convertUTF16ToUTF8String (
65
+ ArrayRef<UTF16>(reinterpret_cast <const UTF16 *>(SystemDirectory),
66
+ wcslen (SystemDirectory)),
67
+ Utf8SystemDir))
68
+ Paths.push_back (Utf8SystemDir);
69
+ }
70
+
71
+ // Get the Windows directory
72
+ wchar_t WindowsDirectory[MAX_PATH];
73
+ if (GetWindowsDirectoryW (WindowsDirectory, MAX_PATH) > 0 ) {
74
+ std::string Utf8WindowsDir;
75
+ if (convertUTF16ToUTF8String (
76
+ ArrayRef<UTF16>(reinterpret_cast <const UTF16 *>(WindowsDirectory),
77
+ wcslen (WindowsDirectory)),
78
+ Utf8WindowsDir))
79
+ Paths.push_back (Utf8WindowsDir);
80
+ }
81
+
82
+ // Get the current working directory
83
+ SmallVector<char , 256 > CWD;
84
+ if (sys::fs::current_path (CWD))
85
+ Paths.push_back (std::string (CWD.begin (), CWD.end ()));
86
+
87
+ // Get the PATH environment variable
88
+ if (std::optional<std::string> PathEnv = sys::Process::GetEnv (" PATH" )) {
89
+ SmallVector<StringRef, 16 > PathList;
90
+ StringRef (*PathEnv).split (PathList, sys::EnvPathSeparator);
91
+ for (auto &Path : PathList)
92
+ Paths.push_back (Path.str ());
93
+ }
94
+
95
+ return Paths;
96
+ }
97
+
98
+ // Custom comparison function for dll name
99
+ static bool compareVersions (StringRef A, StringRef B) {
100
+ auto ParseVersion = [](StringRef S) -> VersionTuple {
101
+ unsigned Pos = S.find_last_of (' _' );
102
+ StringRef VerStr = (Pos == StringRef::npos) ? S : S.substr (Pos + 1 );
103
+ VersionTuple Vt;
104
+ (void )VersionTuple::parse (VerStr, Vt);
105
+ return Vt;
106
+ };
107
+
108
+ VersionTuple VtA = ParseVersion (A);
109
+ VersionTuple VtB = ParseVersion (B);
110
+ return VtA > VtB;
111
+ }
112
+ #endif
113
+
114
+ // On Windows, prefer amdhip64_n.dll where n is ROCm major version and greater
115
+ // value of n takes precedence. If amdhip64_n.dll is not found, fall back to
116
+ // amdhip64.dll. The reason is that a normal driver installation only has
117
+ // amdhip64_n.dll but we do not know what n is since this program may be used
118
+ // with a future version of HIP runtime.
119
+ //
120
+ // On Linux, always use default libamdhip64.so.
121
+ static std::pair<std::string, bool > findNewestHIPDLL () {
122
+ #ifdef _WIN32
123
+ StringRef HipDLLPrefix = " amdhip64_" ;
124
+ StringRef HipDLLSuffix = " .dll" ;
125
+
126
+ std::vector<std::string> SearchPaths = getSearchPaths ();
127
+ std::vector<std::string> DLLNames;
128
+
129
+ for (const auto &Dir : SearchPaths) {
130
+ std::error_code EC;
131
+ for (sys::fs::directory_iterator DirIt (Dir, EC), DirEnd;
132
+ DirIt != DirEnd && !EC; DirIt.increment (EC)) {
133
+ StringRef Filename = sys::path::filename (DirIt->path ());
134
+ if (Filename.starts_with (HipDLLPrefix) &&
135
+ Filename.ends_with (HipDLLSuffix))
136
+ DLLNames.push_back (sys::path::convert_to_slash (DirIt->path ()));
137
+ }
138
+ if (!DLLNames.empty ())
139
+ break ;
140
+ }
141
+
142
+ if (DLLNames.empty ())
143
+ return {" amdhip64.dll" , true };
144
+
145
+ llvm::sort (DLLNames, compareVersions);
146
+ return {DLLNames[0 ], false };
37
147
#else
38
- constexpr const char *DynamicHIPPath = " libamdhip64.so" ;
148
+ // On Linux, fallback to default shared object
149
+ return {" libamdhip64.so" , true };
39
150
#endif
151
+ }
152
+
153
+ int printGPUsByHIP () {
154
+ auto [DynamicHIPPath, IsFallback] = findNewestHIPDLL ();
155
+
156
+ if (Verbose) {
157
+ if (IsFallback)
158
+ outs () << " Using default HIP runtime: " << DynamicHIPPath << ' \n ' ;
159
+ else
160
+ outs () << " Found HIP runtime: " << DynamicHIPPath << ' \n ' ;
161
+ }
40
162
41
163
std::string ErrMsg;
42
164
auto DynlibHandle = std::make_unique<llvm::sys::DynamicLibrary>(
43
- llvm::sys::DynamicLibrary::getPermanentLibrary (DynamicHIPPath, &ErrMsg));
165
+ llvm::sys::DynamicLibrary::getPermanentLibrary (DynamicHIPPath.c_str (),
166
+ &ErrMsg));
44
167
if (!DynlibHandle->isValid ()) {
45
168
llvm::errs () << " Failed to load " << DynamicHIPPath << " : " << ErrMsg
46
169
<< ' \n ' ;
0 commit comments