Skip to content

Commit 6fb2a80

Browse files
[mlir][spirv] Truncate Literal String size at max number words (#142916)
If not truncated the SPIRV serialization would not fail but instead produce an invalid SPIR-V module. --------- Signed-off-by: Davide Grohmann <[email protected]>
1 parent 76197ea commit 6fb2a80

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

mlir/include/mlir/Target/SPIRV/SPIRVBinaryUtils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ constexpr uint32_t kMagicNumber = 0x07230203;
3030
/// The serializer tool ID registered to the Khronos Group
3131
constexpr uint32_t kGeneratorNumber = 22;
3232

33+
/// Max number of words
34+
/// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#_universal_limits
35+
constexpr uint32_t kMaxWordCount = 65535;
36+
37+
/// Max number of words for literal
38+
constexpr uint32_t kMaxLiteralWordCount = kMaxWordCount - 3;
39+
3340
/// Appends a SPRI-V module header to `header` with the given `version` and
3441
/// `idBound`.
3542
void appendModuleHeader(SmallVectorImpl<uint32_t> &header,

mlir/lib/Target/SPIRV/SPIRVBinaryUtils.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
1414
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1515
#include "llvm/Config/llvm-config.h" // for LLVM_VERSION_MAJOR
16+
#include "llvm/Support/Debug.h"
17+
18+
#define DEBUG_TYPE "spirv-binary-utils"
1619

1720
using namespace mlir;
1821

@@ -67,8 +70,19 @@ uint32_t spirv::getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode) {
6770
void spirv::encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
6871
StringRef literal) {
6972
// We need to encode the literal and the null termination.
70-
auto encodingSize = literal.size() / 4 + 1;
71-
auto bufferStartSize = binary.size();
73+
size_t encodingSize = literal.size() / 4 + 1;
74+
size_t sizeOfDataToCopy = literal.size();
75+
if (encodingSize >= kMaxLiteralWordCount) {
76+
// Reserve one word for the null termination.
77+
encodingSize = kMaxLiteralWordCount - 1;
78+
// Do not override the last word (null termination) when copying.
79+
sizeOfDataToCopy = (encodingSize - 1) * 4;
80+
LLVM_DEBUG(llvm::dbgs()
81+
<< "Truncating string literal to max size ("
82+
<< (kMaxLiteralWordCount - 1) << "): " << literal << "\n");
83+
}
84+
size_t bufferStartSize = binary.size();
7285
binary.resize(bufferStartSize + encodingSize, 0);
73-
std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size());
86+
std::memcpy(binary.data() + bufferStartSize, literal.data(),
87+
sizeOfDataToCopy);
7488
}

0 commit comments

Comments
 (0)