Skip to content

Commit 01bfced

Browse files
chsigg authored and ZijunZhaoCCK committed
[mlir][bytecode] Check that bytecode source buffer is sufficiently aligned. (llvm#66380)
Before this change, the `ByteCode` test failed on CentOS 7 with devtoolset-9 because strings happened to be only 8-byte aligned. In general, though, strings have no alignment guarantee. Increase the resource alignment in the test to 32 bytes, adjust the test to align the buffer sufficiently, and add a test that checks the error reported when the buffer is insufficiently aligned.
1 parent 36161b4 commit 01bfced

File tree

3 files changed

+91
-15
lines changed

3 files changed

+91
-15
lines changed

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 22 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
#include "mlir/Bytecode/BytecodeImplementation.h"
1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/Bytecode/Encoding.h"
14-
#include "mlir/IR/BuiltinDialect.h"
1514
#include "mlir/IR/BuiltinOps.h"
1615
#include "mlir/IR/Diagnostics.h"
1716
#include "mlir/IR/OpImplementation.h"
@@ -20,15 +19,13 @@
2019
#include "mlir/Support/LLVM.h"
2120
#include "mlir/Support/LogicalResult.h"
2221
#include "llvm/ADT/ArrayRef.h"
23-
#include "llvm/ADT/MapVector.h"
2422
#include "llvm/ADT/ScopeExit.h"
25-
#include "llvm/ADT/SmallString.h"
2623
#include "llvm/ADT/StringExtras.h"
2724
#include "llvm/ADT/StringRef.h"
2825
#include "llvm/Support/Endian.h"
2926
#include "llvm/Support/MemoryBufferRef.h"
30-
#include "llvm/Support/SaveAndRestore.h"
3127
#include "llvm/Support/SourceMgr.h"
28+
3229
#include <cstddef>
3330
#include <list>
3431
#include <memory>
@@ -93,25 +90,36 @@ namespace {
9390
class EncodingReader {
9491
public:
9592
explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc)
96-
: dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {}
93+
: buffer(contents), dataIt(buffer.begin()), fileLoc(fileLoc) {}
9794
explicit EncodingReader(StringRef contents, Location fileLoc)
9895
: EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()),
9996
contents.size()},
10097
fileLoc) {}
10198

10299
/// Returns true if the entire section has been read.
103-
bool empty() const { return dataIt == dataEnd; }
100+
bool empty() const { return dataIt == buffer.end(); }
104101

105102
/// Returns the remaining size of the bytecode.
106-
size_t size() const { return dataEnd - dataIt; }
103+
size_t size() const { return buffer.end() - dataIt; }
107104

108105
/// Align the current reader position to the specified alignment.
109106
LogicalResult alignTo(unsigned alignment) {
110107
if (!llvm::isPowerOf2_32(alignment))
111108
return emitError("expected alignment to be a power-of-two");
112109

110+
auto isUnaligned = [&](const uint8_t *ptr) {
111+
return ((uintptr_t)ptr & (alignment - 1)) != 0;
112+
};
113+
114+
// Ensure the data buffer was sufficiently aligned in the first place.
115+
if (LLVM_UNLIKELY(isUnaligned(buffer.begin()))) {
116+
return emitError("expected bytecode buffer to be aligned to ", alignment,
117+
", but got pointer: '0x" +
118+
llvm::utohexstr((uintptr_t)buffer.begin()) + "'");
119+
}
120+
113121
// Shift the reader position to the next alignment boundary.
114-
while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) {
122+
while (isUnaligned(dataIt)) {
115123
uint8_t padding;
116124
if (failed(parseByte(padding)))
117125
return failure();
@@ -123,7 +131,7 @@ class EncodingReader {
123131

124132
// Ensure the data iterator is now aligned. This case is unlikely because we
125133
// *just* went through the effort to align the data iterator.
126-
if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))) {
134+
if (LLVM_UNLIKELY(isUnaligned(dataIt))) {
127135
return emitError("expected data iterator aligned to ", alignment,
128136
", but got pointer: '0x" +
129137
llvm::utohexstr((uintptr_t)dataIt) + "'");
@@ -320,8 +328,11 @@ class EncodingReader {
320328
return success();
321329
}
322330

323-
/// The current data iterator, and an iterator to the end of the buffer.
324-
const uint8_t *dataIt, *dataEnd;
331+
/// The bytecode buffer.
332+
ArrayRef<uint8_t> buffer;
333+
334+
/// The current iterator within the 'buffer'.
335+
const uint8_t *dataIt;
325336

326337
/// A location for the bytecode used to report errors.
327338
Location fileLoc;

mlir/unittests/Bytecode/BytecodeTest.cpp

Lines changed: 48 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "mlir/Bytecode/BytecodeReader.h"
109
#include "mlir/Bytecode/BytecodeWriter.h"
1110
#include "mlir/IR/AsmState.h"
1211
#include "mlir/IR/BuiltinAttributes.h"
@@ -22,7 +21,7 @@
2221
using namespace llvm;
2322
using namespace mlir;
2423

25-
using testing::ElementsAre;
24+
using ::testing::StartsWith;
2625

2726
StringLiteral IRWithResources = R"(
2827
module @TestDialectResources attributes {
@@ -31,7 +30,7 @@ module @TestDialectResources attributes {
3130
{-#
3231
dialect_resources: {
3332
builtin: {
34-
resource: "0x1000000001000000020000000300000004000000"
33+
resource: "0x2000000001000000020000000300000004000000"
3534
}
3635
}
3736
#-}
@@ -49,10 +48,19 @@ TEST(Bytecode, MultiModuleWithResource) {
4948
std::string buffer;
5049
llvm::raw_string_ostream ostream(buffer);
5150
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
51+
ostream.flush();
52+
53+
// Create copy of buffer which is aligned to requested resource alignment.
54+
constexpr size_t kAlignment = 0x20;
55+
size_t buffer_size = buffer.size();
56+
buffer.reserve(buffer_size + kAlignment - 1);
57+
size_t pad = ~(uintptr_t)buffer.data() + 1 & kAlignment - 1;
58+
buffer.insert(0, pad, ' ');
59+
StringRef aligned_buffer(buffer.data() + pad, buffer_size);
5260

5361
// Parse it back
5462
OwningOpRef<Operation *> roundTripModule =
55-
parseSourceString<Operation *>(ostream.str(), parseConfig);
63+
parseSourceString<Operation *>(aligned_buffer, parseConfig);
5664
ASSERT_TRUE(roundTripModule);
5765

5866
// FIXME: Parsing external resources does not work on big-endian
@@ -80,3 +88,39 @@ TEST(Bytecode, MultiModuleWithResource) {
8088
checkResourceAttribute(*module);
8189
checkResourceAttribute(*roundTripModule);
8290
}
91+
92+
TEST(Bytecode, InsufficientAlignmentFailure) {
93+
MLIRContext context;
94+
Builder builder(&context);
95+
ParserConfig parseConfig(&context);
96+
OwningOpRef<Operation *> module =
97+
parseSourceString<Operation *>(IRWithResources, parseConfig);
98+
ASSERT_TRUE(module);
99+
100+
// Write the module to bytecode
101+
std::string buffer;
102+
llvm::raw_string_ostream ostream(buffer);
103+
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
104+
ostream.flush();
105+
106+
// Create copy of buffer which is insufficiently aligned.
107+
constexpr size_t kAlignment = 0x20;
108+
size_t buffer_size = buffer.size();
109+
buffer.reserve(buffer_size + kAlignment - 1);
110+
size_t pad = ~(uintptr_t)buffer.data() + kAlignment / 2 + 1 & kAlignment - 1;
111+
buffer.insert(0, pad, ' ');
112+
StringRef misaligned_buffer(buffer.data() + pad, buffer_size);
113+
114+
std::unique_ptr<Diagnostic> diagnostic;
115+
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
116+
diagnostic = std::make_unique<Diagnostic>(std::move(diag));
117+
});
118+
119+
// Try to parse it back and check for alignment error.
120+
OwningOpRef<Operation *> roundTripModule =
121+
parseSourceString<Operation *>(misaligned_buffer, parseConfig);
122+
EXPECT_FALSE(roundTripModule);
123+
ASSERT_TRUE(diagnostic);
124+
EXPECT_THAT(diagnostic->str(),
125+
StartsWith("expected bytecode buffer to be aligned to 32"));
126+
}

utils/bazel/llvm-project-overlay/mlir/unittests/BUILD.bazel

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -359,6 +359,27 @@ cc_test(
359359
],
360360
)
361361

362+
cc_test(
363+
name = "bytecode_tests",
364+
size = "small",
365+
srcs = glob([
366+
"Bytecode/*.cpp",
367+
"Bytecode/*.h",
368+
"Bytecode/*/*.cpp",
369+
"Bytecode/*/*.h",
370+
]),
371+
deps = [
372+
"//llvm:Support",
373+
"//mlir:BytecodeReader",
374+
"//mlir:BytecodeWriter",
375+
"//mlir:IR",
376+
"//mlir:Parser",
377+
"//third-party/unittest:gmock",
378+
"//third-party/unittest:gtest",
379+
"//third-party/unittest:gtest_main",
380+
],
381+
)
382+
362383
cc_test(
363384
name = "conversion_tests",
364385
size = "small",

0 commit comments

Comments
 (0)