Skip to content

[mlir][Target][NVPTX] Add fatbin support to NVPTX compilation. #65398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,23 @@ class TargetOptions {
public:
/// The target representation of the compilation process.
typedef enum {
offload, /// The process should produce an offloading representation. For
/// the NVVM & ROCDL targets this option produces LLVM IR.
assembly, /// The process should produce assembly code.
binary /// The process should produce a binary.
offload = 1, /// The process should produce an offloading representation.
/// For the NVVM & ROCDL targets this option produces LLVM IR.
assembly = 2, /// The process should produce assembly code.
binary = 4, /// The process should produce a binary.
fatbinary = 8, /// The process should produce a fat binary.
binOrFatbin =
binary |
fatbinary, /// The process should produce a binary or fatbinary. It's up
/// to the target to decide which.
} CompilationTarget;

/// Constructor initializing the toolkit path, the list of files to link to,
/// extra command line options & the compilation target. The default
/// compilation target is `binary`.
TargetOptions(StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
CompilationTarget compilationTarget = binary);
CompilationTarget compilationTarget = binOrFatbin);

/// Returns the typeID.
TypeID getTypeID() const;
Expand All @@ -80,7 +85,7 @@ class TargetOptions {
/// appropiate value: ie. `TargetOptions(TypeID::get<DerivedClass>())`.
TargetOptions(TypeID typeID, StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
CompilationTarget compilationTarget = binary);
CompilationTarget compilationTarget = binOrFatbin);

/// Path to the target toolkit.
std::string toolkitPath;
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def GpuModuleToBinaryPass
"Extra files to link to.">,
Option<"cmdOptions", "opts", "std::string", [{""}],
"Command line options to pass to the tools.">,
Option<"compilationTarget", "format", "std::string", [{"bin"}],
Option<"compilationTarget", "format", "std::string", [{"binOrFatbin"}],
"The target representation of the compilation process.">
];
}
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ void GpuModuleToBinaryPass::runOnOperation() {
.Cases("offloading", "llvm", TargetOptions::offload)
.Cases("assembly", "isa", TargetOptions::assembly)
.Cases("binary", "bin", TargetOptions::binary)
.Cases("fatbinary", "fatbin", TargetOptions::fatbinary)
.Case("binOrFatbin", TargetOptions::binOrFatbin)
.Default(-1);
if (targetFormat == -1)
getOperation()->emitError() << "Invalid format specified.";
Expand Down
200 changes: 133 additions & 67 deletions mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,12 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
// Create a temp file.
std::optional<TmpFile> createTemp(StringRef name, StringRef suffix);

// Find the PTXAS compiler. The search order is:
// Find the `tool` path, where `tool` is the name of the binary to search,
// i.e. `ptxas` or `fatbinary`. The search order is:
// 1. The toolkit path in `targetOptions`.
// 2. In the system PATH.
// 3. The path from `getCUDAToolkitPath()`.
std::optional<std::string> findPtxas() const;
std::optional<std::string> findTool(StringRef tool);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document the tool arg


// Target options.
gpu::TargetOptions targetOptions;
Expand Down Expand Up @@ -213,48 +214,58 @@ gpu::GPUModuleOp NVPTXSerializer::getOperation() {
return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
}

std::optional<std::string> NVPTXSerializer::findPtxas() const {
// Find the `ptxas` compiler.
std::optional<std::string> NVPTXSerializer::findTool(StringRef tool) {
// Find the `tool` path.
// 1. Check the toolkit path given in the command line.
StringRef pathRef = targetOptions.getToolkitPath();
SmallVector<char, 256> path;
if (pathRef.size()) {
path.insert(path.begin(), pathRef.begin(), pathRef.end());
llvm::sys::path::append(path, "bin", "ptxas");
llvm::sys::path::append(path, "bin", tool);
if (llvm::sys::fs::can_execute(path))
return StringRef(path.data(), path.size()).str();
}

// 2. Check PATH.
if (std::optional<std::string> ptxasCompiler =
llvm::sys::Process::FindInEnvPath("PATH", "ptxas"))
llvm::sys::Process::FindInEnvPath("PATH", tool))
return *ptxasCompiler;

// 3. Check `getCUDAToolkitPath()`.
pathRef = getCUDAToolkitPath();
path.clear();
if (pathRef.size()) {
path.insert(path.begin(), pathRef.begin(), pathRef.end());
llvm::sys::path::append(path, "bin", "ptxas");
llvm::sys::path::append(path, "bin", tool);
if (llvm::sys::fs::can_execute(path))
return StringRef(path.data(), path.size()).str();
}
getOperation().emitError()
<< "Couldn't find the `" << tool
<< "` binary. Please specify the toolkit "
"path, add the compiler to $PATH, or set one of the environment "
"variables in `NVVM::getCUDAToolkitPath()`.";
return std::nullopt;
}

// TODO: clean this method & have a generic tool driver or never emit binaries
// with this mechanism and let another stage take care of it.
std::optional<SmallVector<char, 0>>
NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
// Find the PTXAS compiler.
std::optional<std::string> ptxasCompiler = findPtxas();
if (!ptxasCompiler) {
getOperation().emitError()
<< "Couldn't find the `ptxas` compiler. Please specify the toolkit "
"path, add the compiler to $PATH, or set one of the environment "
"variables in `NVVM::getCUDAToolkitPath()`.";
// Determine if the serializer should create a fatbinary with the PTX embeded
// or a simple CUBIN binary.
const bool createFatbin =
(targetOptions.getCompilationTarget() & gpu::TargetOptions::fatbinary) ==
gpu::TargetOptions::fatbinary;

// Find the `ptxas` & `fatbinary` tools.
std::optional<std::string> ptxasCompiler = findTool("ptxas");
if (!ptxasCompiler)
return std::nullopt;
}
std::optional<std::string> fatbinaryTool = findTool("fatbinary");
if (createFatbin && !fatbinaryTool)
return std::nullopt;
Location loc = getOperation().getLoc();

// Base name for all temp files: mlir-<module name>-<target triple>-<chip>.
std::string basename =
Expand All @@ -268,99 +279,154 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
std::optional<TmpFile> logFile = createTemp(basename, "log");
if (!logFile)
return std::nullopt;
std::optional<TmpFile> cubinFile = createTemp(basename, "cubin");
if (!cubinFile)
std::optional<TmpFile> binaryFile = createTemp(basename, "bin");
if (!binaryFile)
return std::nullopt;
TmpFile cubinFile;
if (createFatbin) {
Twine cubinFilename = ptxFile->first + ".cubin";
cubinFile = TmpFile(cubinFilename.str(), llvm::FileRemover(cubinFilename));
} else {
cubinFile.first = binaryFile->first;
}

std::error_code ec;
// Dump the PTX to a temp file.
{
llvm::raw_fd_ostream ptxStream(ptxFile->first, ec);
if (ec) {
getOperation().emitError()
<< "Couldn't open the file: `" << ptxFile->first
<< "`, error message: " << ec.message();
emitError(loc) << "Couldn't open the file: `" << ptxFile->first
<< "`, error message: " << ec.message();
return std::nullopt;
}
ptxStream << ptxCode;
if (ptxStream.has_error()) {
getOperation().emitError()
<< "An error occurred while writing the PTX to: `" << ptxFile->first
<< "`.";
emitError(loc) << "An error occurred while writing the PTX to: `"
<< ptxFile->first << "`.";
return std::nullopt;
}
ptxStream.flush();
}

// Create PTX args.
// Command redirects.
std::optional<StringRef> redirects[] = {
std::nullopt,
logFile->first,
logFile->first,
};

// Get any extra args passed in `targetOptions`.
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
targetOptions.tokenizeCmdOptions();

// Create ptxas args.
std::string optLevel = std::to_string(this->optLevel);
SmallVector<StringRef, 12> ptxasArgs(
{StringRef("ptxas"), StringRef("-arch"), getTarget().getChip(),
StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile->first),
StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile.first),
"--opt-level", optLevel});

std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
targetOptions.tokenizeCmdOptions();
for (auto arg : cmdOpts.second)
ptxasArgs.push_back(arg);
bool useFatbin32 = false;
for (auto cArg : cmdOpts.second) {
// All `cmdOpts` are for `ptxas` except `-32` which passes `-32` to
// `fatbinary`, indicating a 32-bit target. By default a 64-bit target is
// assumed.
if (StringRef arg(cArg); arg != "-32")
ptxasArgs.push_back(arg);
else
useFatbin32 = true;
}

std::optional<StringRef> redirects[] = {
std::nullopt,
logFile->first,
logFile->first,
};
// Create the `fatbinary` args.
StringRef chip = getTarget().getChip();
// Remove the arch prefix to obtain the compute capability.
chip.consume_front("sm_"), chip.consume_front("compute_");
// Embed the cubin object.
std::string cubinArg =
llvm::formatv("--image3=kind=elf,sm={0},file={1}", chip, cubinFile.first)
.str();
// Embed the PTX file so the driver can JIT if needed.
std::string ptxArg =
llvm::formatv("--image3=kind=ptx,sm={0},file={1}", chip, ptxFile->first)
.str();
SmallVector<StringRef, 6> fatbinArgs({StringRef("fatbinary"),
useFatbin32 ? "-32" : "-64", cubinArg,
ptxArg, "--create", binaryFile->first});

// Dump tool invocation commands.
#define DEBUG_TYPE "serialize-to-binary"
LLVM_DEBUG({
llvm::dbgs() << "Tool invocation for module: "
<< getOperation().getNameAttr() << "\n";
llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
llvm::dbgs() << "\n";
if (createFatbin) {
llvm::interleave(fatbinArgs, llvm::dbgs(), " ");
llvm::dbgs() << "\n";
}
});
#undef DEBUG_TYPE

// Invoke PTXAS.
// Helper function for printing tool error logs.
std::string message;
if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
/*Env=*/std::nullopt,
/*Redirects=*/redirects,
/*SecondsToWait=*/0,
/*MemoryLimit=*/0,
/*ErrMsg=*/&message)) {
auto emitLogError =
[&](StringRef toolName) -> std::optional<SmallVector<char, 0>> {
if (message.empty()) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasStderr =
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
llvm::MemoryBuffer::getFile(logFile->first);
if (ptxasStderr)
getOperation().emitError() << "PTXAS invocation failed. PTXAS log:\n"
<< ptxasStderr->get()->getBuffer();
if (toolStderr)
emitError(loc) << toolName << " invocation failed. Log:\n"
<< toolStderr->get()->getBuffer();
else
getOperation().emitError() << "PTXAS invocation failed.";
emitError(loc) << toolName << " invocation failed.";
return std::nullopt;
}
getOperation().emitError()
<< "PTXAS invocation failed, error message: " << message;
emitError(loc) << toolName
<< " invocation failed, error message: " << message;
return std::nullopt;
}
};

// Dump the output of PTXAS, helpful if the verbose flag was passed.
// Invoke PTXAS.
if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
/*Env=*/std::nullopt,
/*Redirects=*/redirects,
/*SecondsToWait=*/0,
/*MemoryLimit=*/0,
/*ErrMsg=*/&message))
return emitLogError("`ptxas`");

// Invoke `fatbin`.
message.clear();
if (createFatbin && llvm::sys::ExecuteAndWait(*fatbinaryTool, fatbinArgs,
/*Env=*/std::nullopt,
/*Redirects=*/redirects,
/*SecondsToWait=*/0,
/*MemoryLimit=*/0,
/*ErrMsg=*/&message))
return emitLogError("`fatbinary`");

// Dump the output of the tools, helpful if the verbose flag was passed.
#define DEBUG_TYPE "serialize-to-binary"
LLVM_DEBUG({
llvm::dbgs() << "PTXAS invocation for module: "
<< getOperation().getNameAttr() << "\n";
llvm::dbgs() << "Command: ";
llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
llvm::dbgs() << "\n";
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasLog =
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
llvm::MemoryBuffer::getFile(logFile->first);
if (ptxasLog && (*ptxasLog)->getBuffer().size()) {
llvm::dbgs() << "Output:\n" << (*ptxasLog)->getBuffer() << "\n";
if (logBuffer && (*logBuffer)->getBuffer().size()) {
llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
llvm::dbgs().flush();
}
});
#undef DEBUG_TYPE

// Read the cubin file.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cubinBuffer =
llvm::MemoryBuffer::getFile(cubinFile->first);
if (!cubinBuffer) {
getOperation().emitError()
<< "Couldn't open the file: `" << cubinFile->first
<< "`, error message: " << cubinBuffer.getError().message();
// Read the fatbin.
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer =
llvm::MemoryBuffer::getFile(binaryFile->first);
if (!binaryBuffer) {
emitError(loc) << "Couldn't open the file: `" << binaryFile->first
<< "`, error message: " << binaryBuffer.getError().message();
return std::nullopt;
}
StringRef cubinStr = (*cubinBuffer)->getBuffer();
return SmallVector<char, 0>(cubinStr.begin(), cubinStr.end());
StringRef fatbin = (*binaryBuffer)->getBuffer();
return SmallVector<char, 0>(fatbin.begin(), fatbin.end());
}

#if MLIR_NVPTXCOMPILER_ENABLED == 1
Expand Down