Skip to content

Commit 3ede817

Browse files
authored
[Libomptarget] Fix JIT on the NVPTX target by calling ptx manually (#77801)
Summary: Recently a patch added an assertion in the GlobalHandler to indicate when an ELF was not used. This began to fire whenever NVPTX JIT was used, because the JIT pass output a PTX file instead of an ELF. The CUModuleLoad method consumes `.s` internally and compiles it to a cubin, however, this is too late as we perform several checks on the ELF directly for the presence of certain symbols and to read some necessary constants. This results in inconsistent behaviour. To address this, this patch simply calls `ptxas` manually, similar to how `lld` is called for the AMDGPU JIT pass. This is inevitably going to be slower than simply passing it to the CUDA routine due to the overhead involved in file IO and a fork call, but it's necessary for correctness. CUDA provides an API for compiling PTX manually. However, this only started showing up in CUDA 11.1 and is only provided "officially" in a static library. The `libnvidia-ptxjitcompiler.so` next to the CUDA driver has the same symbols and can likely be used as a replacement. This would be the faster solution. However, given that it's not documented it may have some issues.
1 parent 114e6d7 commit 3ede817

File tree

1 file changed

+62
-0
lines changed
  • openmp/libomptarget/plugins-nextgen/cuda/src

1 file changed

+62
-0
lines changed

openmp/libomptarget/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
#include "llvm/Frontend/OpenMP/OMPConstants.h"
2929
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
3030
#include "llvm/Support/Error.h"
31+
#include "llvm/Support/FileOutputBuffer.h"
32+
#include "llvm/Support/FileSystem.h"
33+
#include "llvm/Support/Program.h"
3134

3235
namespace llvm {
3336
namespace omp {
@@ -397,6 +400,65 @@ struct CUDADeviceTy : public GenericDeviceTy {
397400
return callGlobalCtorDtorCommon(Plugin, Image, /*IsCtor=*/false);
398401
}
399402

403+
Expected<std::unique_ptr<MemoryBuffer>>
404+
doJITPostProcessing(std::unique_ptr<MemoryBuffer> MB) const override {
405+
// TODO: We should be able to use the 'nvidia-ptxjitcompiler' interface to
406+
// avoid the call to 'ptxas'.
407+
SmallString<128> PTXInputFilePath;
408+
std::error_code EC = sys::fs::createTemporaryFile("nvptx-pre-link-jit", "s",
409+
PTXInputFilePath);
410+
if (EC)
411+
return Plugin::error("Failed to create temporary file for ptxas");
412+
413+
// Write the file's contents to the output file.
414+
Expected<std::unique_ptr<FileOutputBuffer>> OutputOrErr =
415+
FileOutputBuffer::create(PTXInputFilePath, MB->getBuffer().size());
416+
if (!OutputOrErr)
417+
return OutputOrErr.takeError();
418+
std::unique_ptr<FileOutputBuffer> Output = std::move(*OutputOrErr);
419+
llvm::copy(MB->getBuffer(), Output->getBufferStart());
420+
if (Error E = Output->commit())
421+
return std::move(E);
422+
423+
SmallString<128> PTXOutputFilePath;
424+
EC = sys::fs::createTemporaryFile("nvptx-post-link-jit", "cubin",
425+
PTXOutputFilePath);
426+
if (EC)
427+
return Plugin::error("Failed to create temporary file for ptxas");
428+
429+
// Try to find `ptxas` in the path to compile the PTX to a binary.
430+
const auto ErrorOrPath = sys::findProgramByName("ptxas");
431+
if (!ErrorOrPath)
432+
return Plugin::error("Failed to find 'ptxas' on the PATH.");
433+
434+
std::string Arch = getComputeUnitKind();
435+
StringRef Args[] = {*ErrorOrPath,
436+
"-m64",
437+
"-O2",
438+
"--gpu-name",
439+
Arch,
440+
"--output-file",
441+
PTXOutputFilePath,
442+
PTXInputFilePath};
443+
444+
std::string ErrMsg;
445+
if (sys::ExecuteAndWait(*ErrorOrPath, Args, std::nullopt, {}, 0, 0,
446+
&ErrMsg))
447+
return Plugin::error("Running 'ptxas' failed: %s\n", ErrMsg.c_str());
448+
449+
auto BufferOrErr = MemoryBuffer::getFileOrSTDIN(PTXOutputFilePath.data());
450+
if (!BufferOrErr)
451+
return Plugin::error("Failed to open temporary file for ptxas");
452+
453+
// Clean up the temporary files afterwards.
454+
if (sys::fs::remove(PTXOutputFilePath))
455+
return Plugin::error("Failed to remove temporary file for ptxas");
456+
if (sys::fs::remove(PTXInputFilePath))
457+
return Plugin::error("Failed to remove temporary file for ptxas");
458+
459+
return std::move(*BufferOrErr);
460+
}
461+
400462
/// Allocate and construct a CUDA kernel.
401463
Expected<GenericKernelTy &>
402464
constructKernel(const __tgt_offload_entry &KernelEntry) override {

0 commit comments

Comments
 (0)