11
11
//
12
12
// ===----------------------------------------------------------------------===//
13
13
14
+ #include " llvm/Support/CommandLine.h"
15
+ #include " llvm/Support/ConvertUTF.h"
14
16
#include " llvm/Support/DynamicLibrary.h"
15
17
#include " llvm/Support/Error.h"
18
+ #include " llvm/Support/FileSystem.h"
19
+ #include " llvm/Support/Path.h"
20
+ #include " llvm/Support/Process.h"
21
+ #include " llvm/Support/Program.h"
16
22
#include " llvm/Support/raw_ostream.h"
23
+ #include < algorithm>
24
+ #include < string>
25
+ #include < vector>
26
+
27
+ #ifdef _WIN32
28
+ #include < windows.h>
29
+ #endif
17
30
18
31
using namespace llvm ;
19
32
@@ -31,16 +44,118 @@ typedef hipError_t (*hipGetDeviceCount_t)(int *);
31
44
typedef hipError_t (*hipDeviceGet_t)(int *, int );
32
45
typedef hipError_t (*hipGetDeviceProperties_t)(hipDeviceProp_t *, int );
33
46
34
- int printGPUsByHIP () {
47
+ extern cl::opt<bool > Verbose;
48
+
35
49
#ifdef _WIN32
36
- constexpr const char *DynamicHIPPath = " amdhip64.dll" ;
50
+ static std::vector<std::string> getSearchPaths () {
51
+ std::vector<std::string> Paths;
52
+
53
+ // Get the directory of the current executable
54
+ if (auto MainExe = sys::fs::getMainExecutable (nullptr , nullptr );
55
+ !MainExe.empty ())
56
+ Paths.push_back (sys::path::parent_path (MainExe).str ());
57
+
58
+ // Get the system directory
59
+ wchar_t SystemDirectory[MAX_PATH];
60
+ if (GetSystemDirectoryW (SystemDirectory, MAX_PATH) > 0 ) {
61
+ std::string Utf8SystemDir;
62
+ if (convertUTF16ToUTF8String (
63
+ ArrayRef<UTF16>(reinterpret_cast <const UTF16 *>(SystemDirectory),
64
+ wcslen (SystemDirectory)),
65
+ Utf8SystemDir))
66
+ Paths.push_back (Utf8SystemDir);
67
+ }
68
+
69
+ // Get the Windows directory
70
+ wchar_t WindowsDirectory[MAX_PATH];
71
+ if (GetWindowsDirectoryW (WindowsDirectory, MAX_PATH) > 0 ) {
72
+ std::string Utf8WindowsDir;
73
+ if (convertUTF16ToUTF8String (
74
+ ArrayRef<UTF16>(reinterpret_cast <const UTF16 *>(WindowsDirectory),
75
+ wcslen (WindowsDirectory)),
76
+ Utf8WindowsDir))
77
+ Paths.push_back (Utf8WindowsDir);
78
+ }
79
+
80
+ // Get the current working directory
81
+ SmallVector<char , 256 > CWD;
82
+ if (sys::fs::current_path (CWD))
83
+ Paths.push_back (std::string (CWD.begin (), CWD.end ()));
84
+
85
+ // Get the PATH environment variable
86
+ if (auto PathEnv = sys::Process::GetEnv (" PATH" )) {
87
+ SmallVector<StringRef, 16 > PathList;
88
+ StringRef (*PathEnv).split (PathList, sys::EnvPathSeparator);
89
+ for (auto &Path : PathList)
90
+ Paths.push_back (Path.str ());
91
+ }
92
+
93
+ return Paths;
94
+ }
95
+
96
+ // Custom comparison function for dll name
97
+ static bool compareVersions (const std::string &a, const std::string &b) {
98
+ // Extract version numbers
99
+ int versionA = std::stoi (a.substr (a.find_last_of (' _' ) + 1 ));
100
+ int versionB = std::stoi (b.substr (b.find_last_of (' _' ) + 1 ));
101
+ return versionA > versionB;
102
+ }
103
+
104
+ #endif
105
+
106
+ // On Windows, prefer amdhip64_n.dll where n is ROCm major version and greater
107
+ // value of n takes precedence. If amdhip64_n.dll is not found, fall back to
108
+ // amdhip64.dll. The reason is that a normal driver installation only has
109
+ // amdhip64_n.dll but we do not know what n is since this progrm may be used
110
+ // with a future version of HIP runtime.
111
+ //
112
+ // On Linux, always use default libamdhip64.so.
113
+ static std::pair<std::string, bool > findNewestHIPDLL () {
114
+ #ifdef _WIN32
115
+ StringRef HipDLLPrefix = " amdhip64_" ;
116
+ StringRef HipDLLSuffix = " .dll" ;
117
+
118
+ std::vector<std::string> SearchPaths = getSearchPaths ();
119
+ std::vector<std::string> DLLNames;
120
+
121
+ for (const auto &Dir : SearchPaths) {
122
+ std::error_code EC;
123
+ for (sys::fs::directory_iterator DirIt (Dir, EC), DirEnd;
124
+ DirIt != DirEnd && !EC; DirIt.increment (EC)) {
125
+ StringRef Filename = sys::path::filename (DirIt->path ());
126
+ if (Filename.starts_with (HipDLLPrefix) &&
127
+ Filename.ends_with (HipDLLSuffix))
128
+ DLLNames.push_back (sys::path::convert_to_slash (DirIt->path ()));
129
+ }
130
+ if (!DLLNames.empty ())
131
+ break ;
132
+ }
133
+
134
+ if (DLLNames.empty ())
135
+ return {" amdhip64.dll" , true };
136
+
137
+ std::sort (DLLNames.begin (), DLLNames.end (), compareVersions);
138
+ return {DLLNames[0 ], false };
37
139
#else
38
- constexpr const char *DynamicHIPPath = " libamdhip64.so" ;
140
+ // On Linux, fallback to default shared object
141
+ return {" libamdhip64.so" , true };
39
142
#endif
143
+ }
144
+
145
+ int printGPUsByHIP () {
146
+ auto [DynamicHIPPath, IsFallback] = findNewestHIPDLL ();
147
+
148
+ if (Verbose) {
149
+ if (IsFallback)
150
+ outs () << " Using default HIP runtime: " << DynamicHIPPath << " \n " ;
151
+ else
152
+ outs () << " Found HIP runtime: " << DynamicHIPPath << " \n " ;
153
+ }
40
154
41
155
std::string ErrMsg;
42
156
auto DynlibHandle = std::make_unique<llvm::sys::DynamicLibrary>(
43
- llvm::sys::DynamicLibrary::getPermanentLibrary (DynamicHIPPath, &ErrMsg));
157
+ llvm::sys::DynamicLibrary::getPermanentLibrary (DynamicHIPPath.c_str (),
158
+ &ErrMsg));
44
159
if (!DynlibHandle->isValid ()) {
45
160
llvm::errs () << " Failed to load " << DynamicHIPPath << " : " << ErrMsg
46
161
<< ' \n ' ;
0 commit comments