Skip to content

Commit c16adb0

Browse files
authored
[mlir][Target][NVPTX] Add fatbin support to NVPTX compilation. (#65398)
Currently, the NVPTX tool compilation path only calls `ptxas`; thus, the GPU running the binary must be an exact match of the arch of the target, or else the runtime throws an error due to the arch mismatch. This patch adds a call to `fatbinary`, creating a fat binary with the cubin object and the PTX code, allowing the driver to JIT the PTX at runtime if there's an arch mismatch.
1 parent 43c2036 commit c16adb0

File tree

4 files changed

+147
-74
lines changed

4 files changed

+147
-74
lines changed

mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,23 @@ class TargetOptions {
4343
public:
4444
/// The target representation of the compilation process.
4545
typedef enum {
46-
offload, /// The process should produce an offloading representation. For
47-
/// the NVVM & ROCDL targets this option produces LLVM IR.
48-
assembly, /// The process should produce assembly code.
49-
binary /// The process should produce a binary.
46+
offload = 1, /// The process should produce an offloading representation.
47+
/// For the NVVM & ROCDL targets this option produces LLVM IR.
48+
assembly = 2, /// The process should produce assembly code.
49+
binary = 4, /// The process should produce a binary.
50+
fatbinary = 8, /// The process should produce a fat binary.
51+
binOrFatbin =
52+
binary |
53+
fatbinary, /// The process should produce a binary or fatbinary. It's up
54+
/// to the target to decide which.
5055
} CompilationTarget;
5156

5257
/// Constructor initializing the toolkit path, the list of files to link to,
5358
/// extra command line options & the compilation target. The default
5459
/// compilation target is `binary`.
5560
TargetOptions(StringRef toolkitPath = {},
5661
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
57-
CompilationTarget compilationTarget = binary);
62+
CompilationTarget compilationTarget = binOrFatbin);
5863

5964
/// Returns the typeID.
6065
TypeID getTypeID() const;
@@ -80,7 +85,7 @@ class TargetOptions {
8085
/// appropiate value: ie. `TargetOptions(TypeID::get<DerivedClass>())`.
8186
TargetOptions(TypeID typeID, StringRef toolkitPath = {},
8287
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
83-
CompilationTarget compilationTarget = binary);
88+
CompilationTarget compilationTarget = binOrFatbin);
8489

8590
/// Path to the target toolkit.
8691
std::string toolkitPath;

mlir/include/mlir/Dialect/GPU/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def GpuModuleToBinaryPass
7777
"Extra files to link to.">,
7878
Option<"cmdOptions", "opts", "std::string", [{""}],
7979
"Command line options to pass to the tools.">,
80-
Option<"compilationTarget", "format", "std::string", [{"bin"}],
80+
Option<"compilationTarget", "format", "std::string", [{"binOrFatbin"}],
8181
"The target representation of the compilation process.">
8282
];
8383
}

mlir/lib/Dialect/GPU/Transforms/ModuleToBinary.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ void GpuModuleToBinaryPass::runOnOperation() {
6161
.Cases("offloading", "llvm", TargetOptions::offload)
6262
.Cases("assembly", "isa", TargetOptions::assembly)
6363
.Cases("binary", "bin", TargetOptions::binary)
64+
.Cases("fatbinary", "fatbin", TargetOptions::fatbinary)
65+
.Case("binOrFatbin", TargetOptions::binOrFatbin)
6466
.Default(-1);
6567
if (targetFormat == -1)
6668
getOperation()->emitError() << "Invalid format specified.";

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 133 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,12 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
180180
// Create a temp file.
181181
std::optional<TmpFile> createTemp(StringRef name, StringRef suffix);
182182

183-
// Find the PTXAS compiler. The search order is:
183+
// Find the `tool` path, where `tool` is the name of the binary to search,
184+
// i.e. `ptxas` or `fatbinary`. The search order is:
184185
// 1. The toolkit path in `targetOptions`.
185186
// 2. In the system PATH.
186187
// 3. The path from `getCUDAToolkitPath()`.
187-
std::optional<std::string> findPtxas() const;
188+
std::optional<std::string> findTool(StringRef tool);
188189

189190
// Target options.
190191
gpu::TargetOptions targetOptions;
@@ -213,48 +214,58 @@ gpu::GPUModuleOp NVPTXSerializer::getOperation() {
213214
return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
214215
}
215216

216-
std::optional<std::string> NVPTXSerializer::findPtxas() const {
217-
// Find the `ptxas` compiler.
217+
std::optional<std::string> NVPTXSerializer::findTool(StringRef tool) {
218+
// Find the `tool` path.
218219
// 1. Check the toolkit path given in the command line.
219220
StringRef pathRef = targetOptions.getToolkitPath();
220221
SmallVector<char, 256> path;
221222
if (pathRef.size()) {
222223
path.insert(path.begin(), pathRef.begin(), pathRef.end());
223-
llvm::sys::path::append(path, "bin", "ptxas");
224+
llvm::sys::path::append(path, "bin", tool);
224225
if (llvm::sys::fs::can_execute(path))
225226
return StringRef(path.data(), path.size()).str();
226227
}
227228

228229
// 2. Check PATH.
229230
if (std::optional<std::string> ptxasCompiler =
230-
llvm::sys::Process::FindInEnvPath("PATH", "ptxas"))
231+
llvm::sys::Process::FindInEnvPath("PATH", tool))
231232
return *ptxasCompiler;
232233

233234
// 3. Check `getCUDAToolkitPath()`.
234235
pathRef = getCUDAToolkitPath();
235236
path.clear();
236237
if (pathRef.size()) {
237238
path.insert(path.begin(), pathRef.begin(), pathRef.end());
238-
llvm::sys::path::append(path, "bin", "ptxas");
239+
llvm::sys::path::append(path, "bin", tool);
239240
if (llvm::sys::fs::can_execute(path))
240241
return StringRef(path.data(), path.size()).str();
241242
}
243+
getOperation().emitError()
244+
<< "Couldn't find the `" << tool
245+
<< "` binary. Please specify the toolkit "
246+
"path, add the compiler to $PATH, or set one of the environment "
247+
"variables in `NVVM::getCUDAToolkitPath()`.";
242248
return std::nullopt;
243249
}
244250

245251
// TODO: clean this method & have a generic tool driver or never emit binaries
246252
// with this mechanism and let another stage take care of it.
247253
std::optional<SmallVector<char, 0>>
248254
NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
249-
// Find the PTXAS compiler.
250-
std::optional<std::string> ptxasCompiler = findPtxas();
251-
if (!ptxasCompiler) {
252-
getOperation().emitError()
253-
<< "Couldn't find the `ptxas` compiler. Please specify the toolkit "
254-
"path, add the compiler to $PATH, or set one of the environment "
255-
"variables in `NVVM::getCUDAToolkitPath()`.";
255+
// Determine if the serializer should create a fatbinary with the PTX embeded
256+
// or a simple CUBIN binary.
257+
const bool createFatbin =
258+
(targetOptions.getCompilationTarget() & gpu::TargetOptions::fatbinary) ==
259+
gpu::TargetOptions::fatbinary;
260+
261+
// Find the `ptxas` & `fatbinary` tools.
262+
std::optional<std::string> ptxasCompiler = findTool("ptxas");
263+
if (!ptxasCompiler)
256264
return std::nullopt;
257-
}
265+
std::optional<std::string> fatbinaryTool = findTool("fatbinary");
266+
if (createFatbin && !fatbinaryTool)
267+
return std::nullopt;
268+
Location loc = getOperation().getLoc();
258269

259270
// Base name for all temp files: mlir-<module name>-<target triple>-<chip>.
260271
std::string basename =
@@ -268,99 +279,154 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
268279
std::optional<TmpFile> logFile = createTemp(basename, "log");
269280
if (!logFile)
270281
return std::nullopt;
271-
std::optional<TmpFile> cubinFile = createTemp(basename, "cubin");
272-
if (!cubinFile)
282+
std::optional<TmpFile> binaryFile = createTemp(basename, "bin");
283+
if (!binaryFile)
273284
return std::nullopt;
285+
TmpFile cubinFile;
286+
if (createFatbin) {
287+
Twine cubinFilename = ptxFile->first + ".cubin";
288+
cubinFile = TmpFile(cubinFilename.str(), llvm::FileRemover(cubinFilename));
289+
} else {
290+
cubinFile.first = binaryFile->first;
291+
}
274292

275293
std::error_code ec;
276294
// Dump the PTX to a temp file.
277295
{
278296
llvm::raw_fd_ostream ptxStream(ptxFile->first, ec);
279297
if (ec) {
280-
getOperation().emitError()
281-
<< "Couldn't open the file: `" << ptxFile->first
282-
<< "`, error message: " << ec.message();
298+
emitError(loc) << "Couldn't open the file: `" << ptxFile->first
299+
<< "`, error message: " << ec.message();
283300
return std::nullopt;
284301
}
285302
ptxStream << ptxCode;
286303
if (ptxStream.has_error()) {
287-
getOperation().emitError()
288-
<< "An error occurred while writing the PTX to: `" << ptxFile->first
289-
<< "`.";
304+
emitError(loc) << "An error occurred while writing the PTX to: `"
305+
<< ptxFile->first << "`.";
290306
return std::nullopt;
291307
}
292308
ptxStream.flush();
293309
}
294310

295-
// Create PTX args.
311+
// Command redirects.
312+
std::optional<StringRef> redirects[] = {
313+
std::nullopt,
314+
logFile->first,
315+
logFile->first,
316+
};
317+
318+
// Get any extra args passed in `targetOptions`.
319+
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
320+
targetOptions.tokenizeCmdOptions();
321+
322+
// Create ptxas args.
296323
std::string optLevel = std::to_string(this->optLevel);
297324
SmallVector<StringRef, 12> ptxasArgs(
298325
{StringRef("ptxas"), StringRef("-arch"), getTarget().getChip(),
299-
StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile->first),
326+
StringRef(ptxFile->first), StringRef("-o"), StringRef(cubinFile.first),
300327
"--opt-level", optLevel});
301328

302-
std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
303-
targetOptions.tokenizeCmdOptions();
304-
for (auto arg : cmdOpts.second)
305-
ptxasArgs.push_back(arg);
329+
bool useFatbin32 = false;
330+
for (auto cArg : cmdOpts.second) {
331+
// All `cmdOpts` are for `ptxas` except `-32` which passes `-32` to
332+
// `fatbinary`, indicating a 32-bit target. By default a 64-bit target is
333+
// assumed.
334+
if (StringRef arg(cArg); arg != "-32")
335+
ptxasArgs.push_back(arg);
336+
else
337+
useFatbin32 = true;
338+
}
306339

307-
std::optional<StringRef> redirects[] = {
308-
std::nullopt,
309-
logFile->first,
310-
logFile->first,
311-
};
340+
// Create the `fatbinary` args.
341+
StringRef chip = getTarget().getChip();
342+
// Remove the arch prefix to obtain the compute capability.
343+
chip.consume_front("sm_"), chip.consume_front("compute_");
344+
// Embed the cubin object.
345+
std::string cubinArg =
346+
llvm::formatv("--image3=kind=elf,sm={0},file={1}", chip, cubinFile.first)
347+
.str();
348+
// Embed the PTX file so the driver can JIT if needed.
349+
std::string ptxArg =
350+
llvm::formatv("--image3=kind=ptx,sm={0},file={1}", chip, ptxFile->first)
351+
.str();
352+
SmallVector<StringRef, 6> fatbinArgs({StringRef("fatbinary"),
353+
useFatbin32 ? "-32" : "-64", cubinArg,
354+
ptxArg, "--create", binaryFile->first});
355+
356+
// Dump tool invocation commands.
357+
#define DEBUG_TYPE "serialize-to-binary"
358+
LLVM_DEBUG({
359+
llvm::dbgs() << "Tool invocation for module: "
360+
<< getOperation().getNameAttr() << "\n";
361+
llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
362+
llvm::dbgs() << "\n";
363+
if (createFatbin) {
364+
llvm::interleave(fatbinArgs, llvm::dbgs(), " ");
365+
llvm::dbgs() << "\n";
366+
}
367+
});
368+
#undef DEBUG_TYPE
312369

313-
// Invoke PTXAS.
370+
// Helper function for printing tool error logs.
314371
std::string message;
315-
if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
316-
/*Env=*/std::nullopt,
317-
/*Redirects=*/redirects,
318-
/*SecondsToWait=*/0,
319-
/*MemoryLimit=*/0,
320-
/*ErrMsg=*/&message)) {
372+
auto emitLogError =
373+
[&](StringRef toolName) -> std::optional<SmallVector<char, 0>> {
321374
if (message.empty()) {
322-
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasStderr =
375+
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
323376
llvm::MemoryBuffer::getFile(logFile->first);
324-
if (ptxasStderr)
325-
getOperation().emitError() << "PTXAS invocation failed. PTXAS log:\n"
326-
<< ptxasStderr->get()->getBuffer();
377+
if (toolStderr)
378+
emitError(loc) << toolName << " invocation failed. Log:\n"
379+
<< toolStderr->get()->getBuffer();
327380
else
328-
getOperation().emitError() << "PTXAS invocation failed.";
381+
emitError(loc) << toolName << " invocation failed.";
329382
return std::nullopt;
330383
}
331-
getOperation().emitError()
332-
<< "PTXAS invocation failed, error message: " << message;
384+
emitError(loc) << toolName
385+
<< " invocation failed, error message: " << message;
333386
return std::nullopt;
334-
}
387+
};
335388

336-
// Dump the output of PTXAS, helpful if the verbose flag was passed.
389+
// Invoke PTXAS.
390+
if (llvm::sys::ExecuteAndWait(ptxasCompiler.value(), ptxasArgs,
391+
/*Env=*/std::nullopt,
392+
/*Redirects=*/redirects,
393+
/*SecondsToWait=*/0,
394+
/*MemoryLimit=*/0,
395+
/*ErrMsg=*/&message))
396+
return emitLogError("`ptxas`");
397+
398+
// Invoke `fatbin`.
399+
message.clear();
400+
if (createFatbin && llvm::sys::ExecuteAndWait(*fatbinaryTool, fatbinArgs,
401+
/*Env=*/std::nullopt,
402+
/*Redirects=*/redirects,
403+
/*SecondsToWait=*/0,
404+
/*MemoryLimit=*/0,
405+
/*ErrMsg=*/&message))
406+
return emitLogError("`fatbinary`");
407+
408+
// Dump the output of the tools, helpful if the verbose flag was passed.
337409
#define DEBUG_TYPE "serialize-to-binary"
338410
LLVM_DEBUG({
339-
llvm::dbgs() << "PTXAS invocation for module: "
340-
<< getOperation().getNameAttr() << "\n";
341-
llvm::dbgs() << "Command: ";
342-
llvm::interleave(ptxasArgs, llvm::dbgs(), " ");
343-
llvm::dbgs() << "\n";
344-
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasLog =
411+
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
345412
llvm::MemoryBuffer::getFile(logFile->first);
346-
if (ptxasLog && (*ptxasLog)->getBuffer().size()) {
347-
llvm::dbgs() << "Output:\n" << (*ptxasLog)->getBuffer() << "\n";
413+
if (logBuffer && (*logBuffer)->getBuffer().size()) {
414+
llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n";
348415
llvm::dbgs().flush();
349416
}
350417
});
351418
#undef DEBUG_TYPE
352419

353-
// Read the cubin file.
354-
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cubinBuffer =
355-
llvm::MemoryBuffer::getFile(cubinFile->first);
356-
if (!cubinBuffer) {
357-
getOperation().emitError()
358-
<< "Couldn't open the file: `" << cubinFile->first
359-
<< "`, error message: " << cubinBuffer.getError().message();
420+
// Read the fatbin.
421+
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer =
422+
llvm::MemoryBuffer::getFile(binaryFile->first);
423+
if (!binaryBuffer) {
424+
emitError(loc) << "Couldn't open the file: `" << binaryFile->first
425+
<< "`, error message: " << binaryBuffer.getError().message();
360426
return std::nullopt;
361427
}
362-
StringRef cubinStr = (*cubinBuffer)->getBuffer();
363-
return SmallVector<char, 0>(cubinStr.begin(), cubinStr.end());
428+
StringRef fatbin = (*binaryBuffer)->getBuffer();
429+
return SmallVector<char, 0>(fatbin.begin(), fatbin.end());
364430
}
365431

366432
#if MLIR_NVPTXCOMPILER_ENABLED == 1

0 commit comments

Comments
 (0)