Skip to content

Commit b0b8e83

Browse files
authored
[mlir] Fix use-after-free bugs in {RankedTensorType|VectorType}::Builder (#68969)
Previously, these would set their ArrayRef members to reference their storage SmallVectors after a copy-on-write (COW) operation. This leads to a use-after-free if the builder is copied and the original destroyed (as the new builder would still reference the old SmallVector). This could easily accidentally occur in code like (annotated): ```c++ // 1. `VectorType::Builder(type)` constructs a new temporary builder // 2. `.dropDim(0)` updates the temporary builder by reference, and returns a `VectorType::Builder&` // - Modifying the shape is a COW operation, so `storage` is used, and `shape` updated the reference it // 3. Assigning the reference to `auto` copies the builder (via the default C++ copy ctor) // - There's no special handling for `shape` and `storage`, so the new shape points to the old builder's `storage` auto newType = VectorType::Builder(type).dropDim(0); // 4. When this line is reached the original temporary builder is destroyed // - Actually constructing the vector type is now a use-after-free VectorType newVectorType = VectorType(newType); ``` This is fixed with these changes by using `CopyOnWriteArrayRef<T>`, which implements the same functionality, but ensures no dangling references are possible if it's copied. --- The VectorType::Builder also set the ArrayRef<bool> scalableDims member to a temporary SmallVector when the provided scalableDims are empty. This again leads to a use-after-free, and is unnecessary as VectorType::get already handles being passed an empty scalableDims array. These bugs were in-part caught by UBSAN, see: https://lab.llvm.org/buildbot/#/builders/5/builds/37355
1 parent 28e4f97 commit b0b8e83

File tree

3 files changed

+192
-44
lines changed

3 files changed

+192
-44
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 15 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/IR/BuiltinAttributeInterfaces.h"
1313
#include "mlir/IR/BuiltinTypeInterfaces.h"
14+
#include "mlir/Support/ADTExtras.h"
1415

1516
namespace llvm {
1617
class BitVector;
@@ -274,20 +275,14 @@ class RankedTensorType::Builder {
274275
/// Erase a dim from shape @pos.
275276
Builder &dropDim(unsigned pos) {
276277
assert(pos < shape.size() && "overflow");
277-
if (storage.empty())
278-
storage.append(shape.begin(), shape.end());
279-
storage.erase(storage.begin() + pos);
280-
shape = {storage.data(), storage.size()};
278+
shape.erase(pos);
281279
return *this;
282280
}
283281

284282
/// Insert a val into shape @pos.
285283
Builder &insertDim(int64_t val, unsigned pos) {
286284
assert(pos <= shape.size() && "overflow");
287-
if (storage.empty())
288-
storage.append(shape.begin(), shape.end());
289-
storage.insert(storage.begin() + pos, val);
290-
shape = {storage.data(), storage.size()};
285+
shape.insert(pos, val);
291286
return *this;
292287
}
293288

@@ -296,9 +291,7 @@ class RankedTensorType::Builder {
296291
}
297292

298293
private:
299-
ArrayRef<int64_t> shape;
300-
// Owning shape data for copy-on-write operations.
301-
SmallVector<int64_t> storage;
294+
CopyOnWriteArrayRef<int64_t> shape;
302295
Type elementType;
303296
Attribute encoding;
304297
};
@@ -313,27 +306,18 @@ class VectorType::Builder {
313306
public:
314307
/// Build from another VectorType.
315308
explicit Builder(VectorType other)
316-
: shape(other.getShape()), elementType(other.getElementType()),
309+
: elementType(other.getElementType()), shape(other.getShape()),
317310
scalableDims(other.getScalableDims()) {}
318311

319312
/// Build from scratch.
320313
Builder(ArrayRef<int64_t> shape, Type elementType,
321-
unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
322-
: shape(shape), elementType(elementType) {
323-
if (scalableDims.empty())
324-
scalableDims = SmallVector<bool>(shape.size(), false);
325-
else
326-
this->scalableDims = scalableDims;
327-
}
314+
ArrayRef<bool> scalableDims = {})
315+
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
328316

329317
Builder &setShape(ArrayRef<int64_t> newShape,
330318
ArrayRef<bool> newIsScalableDim = {}) {
331-
if (newIsScalableDim.empty())
332-
scalableDims = SmallVector<bool>(shape.size(), false);
333-
else
334-
scalableDims = newIsScalableDim;
335-
336319
shape = newShape;
320+
scalableDims = newIsScalableDim;
337321
return *this;
338322
}
339323

@@ -345,25 +329,16 @@ class VectorType::Builder {
345329
/// Erase a dim from shape @pos.
346330
Builder &dropDim(unsigned pos) {
347331
assert(pos < shape.size() && "overflow");
348-
if (storage.empty())
349-
storage.append(shape.begin(), shape.end());
350-
if (storageScalableDims.empty())
351-
storageScalableDims.append(scalableDims.begin(), scalableDims.end());
352-
storage.erase(storage.begin() + pos);
353-
storageScalableDims.erase(storageScalableDims.begin() + pos);
354-
shape = {storage.data(), storage.size()};
355-
scalableDims =
356-
ArrayRef<bool>(storageScalableDims.data(), storageScalableDims.size());
332+
shape.erase(pos);
333+
if (!scalableDims.empty())
334+
scalableDims.erase(pos);
357335
return *this;
358336
}
359337

360338
/// Set a dim in shape @pos to val.
361339
Builder &setDim(unsigned pos, int64_t val) {
362-
if (storage.empty())
363-
storage.append(shape.begin(), shape.end());
364-
assert(pos < storage.size() && "overflow");
365-
storage[pos] = val;
366-
shape = {storage.data(), storage.size()};
340+
assert(pos < shape.size() && "overflow");
341+
shape.set(pos, val);
367342
return *this;
368343
}
369344

@@ -372,13 +347,9 @@ class VectorType::Builder {
372347
}
373348

374349
private:
375-
ArrayRef<int64_t> shape;
376-
// Owning shape data for copy-on-write operations.
377-
SmallVector<int64_t> storage;
378350
Type elementType;
379-
ArrayRef<bool> scalableDims;
380-
// Owning scalableDims data for copy-on-write operations.
381-
SmallVector<bool> storageScalableDims;
351+
CopyOnWriteArrayRef<int64_t> shape;
352+
CopyOnWriteArrayRef<bool> scalableDims;
382353
};
383354

384355
/// Given an `originalShape` and a `reducedShape` assumed to be a subset of

mlir/include/mlir/Support/ADTExtras.h

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- ADTExtras.h - Extra ADTs for use in MLIR -----------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_SUPPORT_ADTEXTRAS_H
10+
#define MLIR_SUPPORT_ADTEXTRAS_H
11+
12+
#include "llvm/ADT/ArrayRef.h"
13+
#include "llvm/ADT/SmallVector.h"
14+
15+
namespace mlir {
16+
17+
//===----------------------------------------------------------------------===//
18+
// CopyOnWriteArrayRef<T>
19+
//===----------------------------------------------------------------------===//
20+
21+
// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on
22+
// modification. This is for use in the mlir::<Type>::Builders.
23+
template <typename T>
24+
class CopyOnWriteArrayRef {
25+
public:
26+
CopyOnWriteArrayRef(ArrayRef<T> array) : nonOwning(array){};
27+
28+
CopyOnWriteArrayRef &operator=(ArrayRef<T> array) {
29+
nonOwning = array;
30+
owningStorage = {};
31+
return *this;
32+
}
33+
34+
void insert(size_t index, T value) {
35+
SmallVector<T> &vector = ensureCopy();
36+
vector.insert(vector.begin() + index, value);
37+
}
38+
39+
void erase(size_t index) {
40+
// Note: A copy can be avoided when just dropping the front/back dims.
41+
if (isNonOwning() && index == 0) {
42+
nonOwning = nonOwning.drop_front();
43+
} else if (isNonOwning() && index == size() - 1) {
44+
nonOwning = nonOwning.drop_back();
45+
} else {
46+
SmallVector<T> &vector = ensureCopy();
47+
vector.erase(vector.begin() + index);
48+
}
49+
}
50+
51+
void set(size_t index, T value) { ensureCopy()[index] = value; }
52+
53+
size_t size() const { return ArrayRef<T>(*this).size(); }
54+
55+
bool empty() const { return ArrayRef<T>(*this).empty(); }
56+
57+
operator ArrayRef<T>() const {
58+
return nonOwning.empty() ? ArrayRef<T>(owningStorage) : nonOwning;
59+
}
60+
61+
private:
62+
bool isNonOwning() const { return !nonOwning.empty(); }
63+
64+
SmallVector<T> &ensureCopy() {
65+
// Empty non-owning storage signals the array has been copied to the owning
66+
// storage (or both are empty). Note: `nonOwning` should never reference
67+
// `owningStorage`. This can lead to dangling references if the
68+
// CopyOnWriteArrayRef<T> is copied.
69+
if (isNonOwning()) {
70+
owningStorage = SmallVector<T>(nonOwning);
71+
nonOwning = {};
72+
}
73+
return owningStorage;
74+
}
75+
76+
ArrayRef<T> nonOwning;
77+
SmallVector<T> owningStorage;
78+
};
79+
80+
} // namespace mlir
81+
82+
#endif

mlir/unittests/IR/ShapedTypeTest.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,99 @@ TEST(ShapedTypeTest, CloneVector) {
131131
VectorType::get(vectorNewShape, vectorNewType));
132132
}
133133

134+
TEST(ShapedTypeTest, VectorTypeBuilder) {
135+
MLIRContext context;
136+
Type f32 = FloatType::getF32(&context);
137+
138+
SmallVector<int64_t> shape{2, 4, 8, 9, 1};
139+
SmallVector<bool> scalableDims{true, false, true, false, false};
140+
VectorType vectorType = VectorType::get(shape, f32, scalableDims);
141+
142+
{
143+
// Drop some dims.
144+
VectorType dropFrontTwoDims =
145+
VectorType::Builder(vectorType).dropDim(0).dropDim(0);
146+
ASSERT_EQ(vectorType.getElementType(), dropFrontTwoDims.getElementType());
147+
ASSERT_EQ(vectorType.getShape().drop_front(2), dropFrontTwoDims.getShape());
148+
ASSERT_EQ(vectorType.getScalableDims().drop_front(2),
149+
dropFrontTwoDims.getScalableDims());
150+
}
151+
152+
{
153+
// Set some dims.
154+
VectorType setTwoDims =
155+
VectorType::Builder(vectorType).setDim(0, 10).setDim(3, 12);
156+
ASSERT_EQ(setTwoDims.getShape(), ArrayRef<int64_t>({10, 4, 8, 12, 1}));
157+
ASSERT_EQ(vectorType.getElementType(), setTwoDims.getElementType());
158+
ASSERT_EQ(vectorType.getScalableDims(), setTwoDims.getScalableDims());
159+
}
160+
161+
{
162+
// Test for bug from:
163+
// https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
164+
// Constructs a temporary builder, modifies it, copies it to `builder`.
165+
// This used to lead to a use-after-free. Running under sanitizers will
166+
// catch any issues.
167+
VectorType::Builder builder = VectorType::Builder(vectorType).setDim(0, 16);
168+
VectorType newVectorType = VectorType(builder);
169+
ASSERT_EQ(newVectorType.getDimSize(0), 16);
170+
}
171+
172+
{
173+
// Make builder from scratch (without scalable dims) -- this use to lead to
174+
// a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
175+
// Running under sanitizers will catch any issues.
176+
SmallVector<int64_t> shape{1, 2, 3, 4};
177+
VectorType::Builder builder(shape, f32);
178+
ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(shape));
179+
}
180+
181+
{
182+
// Set vector shape (without scalable dims) -- this use to lead to
183+
// a use-after-free see: https://github.com/llvm/llvm-project/pull/68969.
184+
// Running under sanitizers will catch any issues.
185+
VectorType::Builder builder(vectorType);
186+
SmallVector<int64_t> newShape{2, 2};
187+
builder.setShape(newShape);
188+
ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(newShape));
189+
}
190+
}
191+
192+
TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
193+
MLIRContext context;
194+
Type f32 = FloatType::getF32(&context);
195+
196+
SmallVector<int64_t> shape{2, 4, 8, 16, 32};
197+
RankedTensorType tensorType = RankedTensorType::get(shape, f32);
198+
199+
{
200+
// Drop some dims.
201+
RankedTensorType dropFrontTwoDims =
202+
RankedTensorType::Builder(tensorType).dropDim(0).dropDim(1).dropDim(0);
203+
ASSERT_EQ(tensorType.getElementType(), dropFrontTwoDims.getElementType());
204+
ASSERT_EQ(dropFrontTwoDims.getShape(), ArrayRef<int64_t>({16, 32}));
205+
}
206+
207+
{
208+
// Insert some dims.
209+
RankedTensorType insertTwoDims =
210+
RankedTensorType::Builder(tensorType).insertDim(7, 2).insertDim(9, 3);
211+
ASSERT_EQ(tensorType.getElementType(), insertTwoDims.getElementType());
212+
ASSERT_EQ(insertTwoDims.getShape(),
213+
ArrayRef<int64_t>({2, 4, 7, 9, 8, 16, 32}));
214+
}
215+
216+
{
217+
// Test for bug from:
218+
// https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
219+
// Constructs a temporary builder, modifies it, copies it to `builder`.
220+
// This used to lead to a use-after-free. Running under sanitizers will
221+
// catch any issues.
222+
RankedTensorType::Builder builder =
223+
RankedTensorType::Builder(tensorType).dropDim(0);
224+
RankedTensorType newTensorType = RankedTensorType(builder);
225+
ASSERT_EQ(tensorType.getShape().drop_front(), newTensorType.getShape());
226+
}
227+
}
228+
134229
} // namespace

0 commit comments

Comments
 (0)