@@ -180,11 +180,12 @@ class NVPTXSerializer : public SerializeGPUModuleBase {
180
180
// Create a temp file.
181
181
std::optional<TmpFile> createTemp (StringRef name, StringRef suffix);
182
182
183
- // Find the PTXAS compiler. The search order is:
183
+ // Find the `tool` path, where `tool` is the name of the binary to search
184
+ // for, e.g. `ptxas` or `fatbinary`. The search order is:
184
185
// 1. The toolkit path in `targetOptions`.
185
186
// 2. In the system PATH.
186
187
// 3. The path from `getCUDAToolkitPath()`.
187
- std::optional<std::string> findPtxas () const ;
188
+ std::optional<std::string> findTool (StringRef tool) ;
188
189
189
190
// Target options.
190
191
gpu::TargetOptions targetOptions;
@@ -213,48 +214,58 @@ gpu::GPUModuleOp NVPTXSerializer::getOperation() {
213
214
return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation ());
214
215
}
215
216
216
// NVPTXSerializer::findTool (replaces the former `findPtxas`).
//
// Returns the path of the executable named `tool`, or std::nullopt when no
// candidate is usable. Candidates are tried in this order:
//   1. `bin/<tool>` under `targetOptions.getToolkitPath()`, when that path is
//      non-empty and the resulting file passes `can_execute`.
//   2. A lookup of `tool` through the `PATH` environment variable
//      (`FindInEnvPath`).
//   3. `bin/<tool>` under `getCUDAToolkitPath()`, with the same checks as 1.
// When all three fail, an error naming `tool` is emitted on the module
// operation before std::nullopt is returned.
//
// NOTE(review): because the diagnostic is emitted here, every failed lookup
// reports an error, including lookups for a tool the caller may not end up
// using. Callers should only call this for tools they actually need.
// NOTE(review): the new signature drops `const`; presumably this is because
// of the added `getOperation().emitError()` call -- confirm.
- std::optional<std::string> NVPTXSerializer::findPtxas () const {
217
- // Find the `ptxas` compiler .
217
+ std::optional<std::string> NVPTXSerializer::findTool (StringRef tool) {
218
+ // Find the `tool` path .
218
219
// 1. Check the toolkit path given in the command line.
219
220
StringRef pathRef = targetOptions.getToolkitPath ();
220
221
SmallVector<char , 256 > path;
221
222
if (pathRef.size ()) {
222
223
path.insert (path.begin (), pathRef.begin (), pathRef.end ());
223
- llvm::sys::path::append (path, " bin" , " ptxas " );
224
+ llvm::sys::path::append (path, " bin" , tool );
224
225
if (llvm::sys::fs::can_execute (path))
225
226
return StringRef (path.data (), path.size ()).str ();
226
227
}
227
228
228
229
// 2. Check PATH.
// NOTE(review): `ptxasCompiler` holds the path of whichever `tool` was
// requested; the name is a leftover from `findPtxas`.
229
230
if (std::optional<std::string> ptxasCompiler =
230
- llvm::sys::Process::FindInEnvPath (" PATH" , " ptxas " ))
231
+ llvm::sys::Process::FindInEnvPath (" PATH" , tool ))
231
232
return *ptxasCompiler;
232
233
233
234
// 3. Check `getCUDAToolkitPath()`.
// `path` is reused for the second toolkit candidate, hence the `clear()`.
234
235
pathRef = getCUDAToolkitPath ();
235
236
path.clear ();
236
237
if (pathRef.size ()) {
237
238
path.insert (path.begin (), pathRef.begin (), pathRef.end ());
238
- llvm::sys::path::append (path, " bin" , " ptxas " );
239
+ llvm::sys::path::append (path, " bin" , tool );
239
240
if (llvm::sys::fs::can_execute (path))
240
241
return StringRef (path.data (), path.size ()).str ();
241
242
}
// No candidate was usable: report the failure on the module operation and
// signal it to the caller with std::nullopt.
243
+ getOperation ().emitError ()
244
+ << " Couldn't find the `" << tool
245
+ << " ` binary. Please specify the toolkit "
246
+ " path, add the compiler to $PATH, or set one of the environment "
247
+ " variables in `NVVM::getCUDAToolkitPath()`." ;
242
248
return std::nullopt;
243
249
}
244
250
245
251
// TODO: clean this method & have a generic tool driver or never emit binaries
246
252
// with this mechanism and let another stage take care of it.
247
253
std::optional<SmallVector<char , 0 >>
248
254
NVPTXSerializer::compileToBinary (const std::string &ptxCode) {
249
- // Find the PTXAS compiler.
250
- std::optional<std::string> ptxasCompiler = findPtxas ();
251
- if (!ptxasCompiler) {
252
- getOperation ().emitError ()
253
- << " Couldn't find the `ptxas` compiler. Please specify the toolkit "
254
- " path, add the compiler to $PATH, or set one of the environment "
255
- " variables in `NVVM::getCUDAToolkitPath()`." ;
255
+ // Determine if the serializer should create a fatbinary with the PTX embedded
256
+ // or a simple CUBIN binary.
257
+ const bool createFatbin =
258
+ (targetOptions.getCompilationTarget () & gpu::TargetOptions::fatbinary) ==
259
+ gpu::TargetOptions::fatbinary;
260
+
261
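+ // Find the `ptxas` & `fatbinary` tools. NOTE(review): `findTool` emits an
+ // error whenever a lookup fails, so `fatbinary` should only be searched for
+ // when `createFatbin` is set.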
262
+ std::optional<std::string> ptxasCompiler = findTool (" ptxas" );
263
+ if (!ptxasCompiler)
256
264
return std::nullopt;
257
- }
265
+ std::optional<std::string> fatbinaryTool = findTool (" fatbinary" );
266
+ if (createFatbin && !fatbinaryTool)
267
+ return std::nullopt;
268
+ Location loc = getOperation ().getLoc ();
258
269
259
270
// Base name for all temp files: mlir-<module name>-<target triple>-<chip>.
260
271
std::string basename =
@@ -268,99 +279,154 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) {
268
279
std::optional<TmpFile> logFile = createTemp (basename, " log" );
269
280
if (!logFile)
270
281
return std::nullopt;
271
- std::optional<TmpFile> cubinFile = createTemp (basename, " cubin " );
272
- if (!cubinFile )
282
+ std::optional<TmpFile> binaryFile = createTemp (basename, " bin " );
283
+ if (!binaryFile )
273
284
return std::nullopt;
285
+ TmpFile cubinFile;
286
+ if (createFatbin) {
287
+ Twine cubinFilename = ptxFile->first + " .cubin" ;
288
+ cubinFile = TmpFile (cubinFilename.str (), llvm::FileRemover (cubinFilename));
289
+ } else {
290
+ cubinFile.first = binaryFile->first ;
291
+ }
274
292
275
293
std::error_code ec;
276
294
// Dump the PTX to a temp file.
277
295
{
278
296
llvm::raw_fd_ostream ptxStream (ptxFile->first , ec);
279
297
if (ec) {
280
- getOperation ().emitError ()
281
- << " Couldn't open the file: `" << ptxFile->first
282
- << " `, error message: " << ec.message ();
298
+ emitError (loc) << " Couldn't open the file: `" << ptxFile->first
299
+ << " `, error message: " << ec.message ();
283
300
return std::nullopt;
284
301
}
285
302
ptxStream << ptxCode;
286
303
if (ptxStream.has_error ()) {
287
- getOperation ().emitError ()
288
- << " An error occurred while writing the PTX to: `" << ptxFile->first
289
- << " `." ;
304
+ emitError (loc) << " An error occurred while writing the PTX to: `"
305
+ << ptxFile->first << " `." ;
290
306
return std::nullopt;
291
307
}
292
308
ptxStream.flush ();
293
309
}
294
310
295
- // Create PTX args.
311
+ // Command redirects.
312
+ std::optional<StringRef> redirects[] = {
313
+ std::nullopt,
314
+ logFile->first ,
315
+ logFile->first ,
316
+ };
317
+
318
+ // Get any extra args passed in `targetOptions`.
319
+ std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
320
+ targetOptions.tokenizeCmdOptions ();
321
+
322
+ // Create ptxas args.
296
323
std::string optLevel = std::to_string (this ->optLevel );
297
324
SmallVector<StringRef, 12 > ptxasArgs (
298
325
{StringRef (" ptxas" ), StringRef (" -arch" ), getTarget ().getChip (),
299
- StringRef (ptxFile->first ), StringRef (" -o" ), StringRef (cubinFile-> first ),
326
+ StringRef (ptxFile->first ), StringRef (" -o" ), StringRef (cubinFile. first ),
300
327
" --opt-level" , optLevel});
301
328
302
- std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
303
- targetOptions.tokenizeCmdOptions ();
304
- for (auto arg : cmdOpts.second )
305
- ptxasArgs.push_back (arg);
329
+ bool useFatbin32 = false ;
330
+ for (auto cArg : cmdOpts.second ) {
331
+ // All `cmdOpts` are for `ptxas` except `-32` which passes `-32` to
332
+ // `fatbinary`, indicating a 32-bit target. By default a 64-bit target is
333
+ // assumed.
334
+ if (StringRef arg (cArg); arg != " -32" )
335
+ ptxasArgs.push_back (arg);
336
+ else
337
+ useFatbin32 = true ;
338
+ }
306
339
307
- std::optional<StringRef> redirects[] = {
308
- std::nullopt,
309
- logFile->first ,
310
- logFile->first ,
311
- };
340
+ // Create the `fatbinary` args.
341
+ StringRef chip = getTarget ().getChip ();
342
+ // Remove the arch prefix to obtain the compute capability.
343
+ chip.consume_front (" sm_" ), chip.consume_front (" compute_" );
344
+ // Embed the cubin object.
345
+ std::string cubinArg =
346
+ llvm::formatv (" --image3=kind=elf,sm={0},file={1}" , chip, cubinFile.first )
347
+ .str ();
348
+ // Embed the PTX file so the driver can JIT if needed.
349
+ std::string ptxArg =
350
+ llvm::formatv (" --image3=kind=ptx,sm={0},file={1}" , chip, ptxFile->first )
351
+ .str ();
352
+ SmallVector<StringRef, 6 > fatbinArgs ({StringRef (" fatbinary" ),
353
+ useFatbin32 ? " -32" : " -64" , cubinArg,
354
+ ptxArg, " --create" , binaryFile->first });
355
+
356
+ // Dump tool invocation commands.
357
+ #define DEBUG_TYPE " serialize-to-binary"
358
+ LLVM_DEBUG ({
359
+ llvm::dbgs () << " Tool invocation for module: "
360
+ << getOperation ().getNameAttr () << " \n " ;
361
+ llvm::interleave (ptxasArgs, llvm::dbgs (), " " );
362
+ llvm::dbgs () << " \n " ;
363
+ if (createFatbin) {
364
+ llvm::interleave (fatbinArgs, llvm::dbgs (), " " );
365
+ llvm::dbgs () << " \n " ;
366
+ }
367
+ });
368
+ #undef DEBUG_TYPE
312
369
313
- // Invoke PTXAS .
370
+ // Helper function for printing tool error logs .
314
371
std::string message;
315
- if (llvm::sys::ExecuteAndWait (ptxasCompiler.value (), ptxasArgs,
316
- /* Env=*/ std::nullopt,
317
- /* Redirects=*/ redirects,
318
- /* SecondsToWait=*/ 0 ,
319
- /* MemoryLimit=*/ 0 ,
320
- /* ErrMsg=*/ &message)) {
372
+ auto emitLogError =
373
+ [&](StringRef toolName) -> std::optional<SmallVector<char , 0 >> {
321
374
if (message.empty ()) {
322
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasStderr =
375
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
323
376
llvm::MemoryBuffer::getFile (logFile->first );
324
- if (ptxasStderr )
325
- getOperation (). emitError () << " PTXAS invocation failed. PTXAS log :\n "
326
- << ptxasStderr ->get ()->getBuffer ();
377
+ if (toolStderr )
378
+ emitError (loc ) << toolName << " invocation failed. Log :\n "
379
+ << toolStderr ->get ()->getBuffer ();
327
380
else
328
- getOperation (). emitError () << " PTXAS invocation failed." ;
381
+ emitError (loc ) << toolName << " invocation failed." ;
329
382
return std::nullopt;
330
383
}
331
- getOperation (). emitError ()
332
- << " PTXAS invocation failed, error message: " << message;
384
+ emitError (loc) << toolName
385
+ << " invocation failed, error message: " << message;
333
386
return std::nullopt;
334
- }
387
+ };
335
388
336
- // Dump the output of PTXAS, helpful if the verbose flag was passed.
389
+ // Invoke PTXAS.
390
+ if (llvm::sys::ExecuteAndWait (ptxasCompiler.value (), ptxasArgs,
391
+ /* Env=*/ std::nullopt,
392
+ /* Redirects=*/ redirects,
393
+ /* SecondsToWait=*/ 0 ,
394
+ /* MemoryLimit=*/ 0 ,
395
+ /* ErrMsg=*/ &message))
396
+ return emitLogError (" `ptxas`" );
397
+
398
+ // Invoke `fatbinary`.
399
+ message.clear ();
400
+ if (createFatbin && llvm::sys::ExecuteAndWait (*fatbinaryTool, fatbinArgs,
401
+ /* Env=*/ std::nullopt,
402
+ /* Redirects=*/ redirects,
403
+ /* SecondsToWait=*/ 0 ,
404
+ /* MemoryLimit=*/ 0 ,
405
+ /* ErrMsg=*/ &message))
406
+ return emitLogError (" `fatbinary`" );
407
+
408
+ // Dump the output of the tools, helpful if the verbose flag was passed.
337
409
#define DEBUG_TYPE " serialize-to-binary"
338
410
LLVM_DEBUG ({
339
- llvm::dbgs () << " PTXAS invocation for module: "
340
- << getOperation ().getNameAttr () << " \n " ;
341
- llvm::dbgs () << " Command: " ;
342
- llvm::interleave (ptxasArgs, llvm::dbgs (), " " );
343
- llvm::dbgs () << " \n " ;
344
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> ptxasLog =
411
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer =
345
412
llvm::MemoryBuffer::getFile (logFile->first );
346
- if (ptxasLog && (*ptxasLog )->getBuffer ().size ()) {
347
- llvm::dbgs () << " Output:\n " << (*ptxasLog )->getBuffer () << " \n " ;
413
+ if (logBuffer && (*logBuffer )->getBuffer ().size ()) {
414
+ llvm::dbgs () << " Output:\n " << (*logBuffer )->getBuffer () << " \n " ;
348
415
llvm::dbgs ().flush ();
349
416
}
350
417
});
351
418
#undef DEBUG_TYPE
352
419
353
- // Read the cubin file.
354
- llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> cubinBuffer =
355
- llvm::MemoryBuffer::getFile (cubinFile->first );
356
- if (!cubinBuffer) {
357
- getOperation ().emitError ()
358
- << " Couldn't open the file: `" << cubinFile->first
359
- << " `, error message: " << cubinBuffer.getError ().message ();
420
+ // Read the output binary (a fatbin, or a plain cubin when `createFatbin` is
+ // false).
421
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer =
422
+ llvm::MemoryBuffer::getFile (binaryFile->first );
423
+ if (!binaryBuffer) {
424
+ emitError (loc) << " Couldn't open the file: `" << binaryFile->first
425
+ << " `, error message: " << binaryBuffer.getError ().message ();
360
426
return std::nullopt;
361
427
}
362
- StringRef cubinStr = (*cubinBuffer )->getBuffer ();
363
- return SmallVector<char , 0 >(cubinStr .begin (), cubinStr .end ());
428
+ StringRef fatbin = (*binaryBuffer )->getBuffer ();
429
+ return SmallVector<char , 0 >(fatbin .begin (), fatbin .end ());
364
430
}
365
431
366
432
#if MLIR_NVPTXCOMPILER_ENABLED == 1
0 commit comments