Skip to content

Commit 9f2feeb

Browse files
authored
[mlir][gpu][nvptx] Remove null terminator when outputting PTX (#133019)
PTX source files are expected to only contain ASCII text (https://docs.nvidia.com/cuda/parallel-thread-execution/#source-format) and no null terminators. `ptxas` has so far not enforced this but is moving towards doing so. This revealed a problem where the null terminator is getting printed out in the output file in MLIR path when outputting ptx directly. Only add the null on the assembly output path for JIT instead of in output of `moduleToObject `.
1 parent f1c6612 commit 9f2feeb

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -722,12 +722,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
722722
#undef DEBUG_TYPE
723723

724724
// Return PTX if the compilation target is `assembly`.
725-
if (targetOptions.getCompilationTarget() ==
726-
gpu::CompilationTarget::Assembly) {
727-
// Make sure to include the null terminator.
728-
StringRef bin(serializedISA->c_str(), serializedISA->size() + 1);
729-
return SmallVector<char, 0>(bin.begin(), bin.end());
730-
}
725+
if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Assembly)
726+
return SmallVector<char, 0>(serializedISA->begin(), serializedISA->end());
731727

732728
std::optional<SmallVector<char, 0>> result;
733729
moduleToObjectTimer.startTimer();

mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,11 @@ LogicalResult SelectObjectAttrImpl::embedBinary(
116116
llvm::Module *module = moduleTranslation.getLLVMModule();
117117

118118
// Embed the object as a global string.
119+
// Add null for assembly output for JIT paths that expect null-terminated
120+
// strings.
121+
bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
119122
llvm::Constant *binary = llvm::ConstantDataArray::getString(
120-
builder.getContext(), object.getObject().getValue(), false);
123+
builder.getContext(), object.getObject().getValue(), addNull);
121124
llvm::GlobalVariable *serializedObj =
122125
new llvm::GlobalVariable(*module, binary->getType(), true,
123126
llvm::GlobalValue::LinkageTypes::InternalLinkage,

mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToPTX)) {
130130

131131
ASSERT_TRUE(
132132
StringRef(object->data(), object->size()).contains("nvvm_kernel"));
133+
ASSERT_TRUE(StringRef(object->data(), object->size()).count('\0') == 0);
133134
}
134135
}
135136

0 commit comments

Comments
 (0)