Skip to content

[mlir] Fix use-after-free bugs in {RankedTensorType|VectorType}::Builder #68969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 15 additions & 44 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Support/ADTExtras.h"

namespace llvm {
class BitVector;
Expand Down Expand Up @@ -274,20 +275,14 @@ class RankedTensorType::Builder {
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
storage.erase(storage.begin() + pos);
shape = {storage.data(), storage.size()};
shape.erase(pos);
return *this;
}

/// Insert a val into shape @pos.
Builder &insertDim(int64_t val, unsigned pos) {
assert(pos <= shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
storage.insert(storage.begin() + pos, val);
shape = {storage.data(), storage.size()};
shape.insert(pos, val);
return *this;
}

Expand All @@ -296,9 +291,7 @@ class RankedTensorType::Builder {
}

private:
ArrayRef<int64_t> shape;
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
CopyOnWriteArrayRef<int64_t> shape;
Type elementType;
Attribute encoding;
};
Expand All @@ -313,27 +306,18 @@ class VectorType::Builder {
public:
/// Build from another VectorType.
explicit Builder(VectorType other)
: shape(other.getShape()), elementType(other.getElementType()),
: elementType(other.getElementType()), shape(other.getShape()),
scalableDims(other.getScalableDims()) {}

/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
: shape(shape), elementType(elementType) {
if (scalableDims.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
this->scalableDims = scalableDims;
}
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}

Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
if (newIsScalableDim.empty())
scalableDims = SmallVector<bool>(shape.size(), false);
else
scalableDims = newIsScalableDim;

shape = newShape;
scalableDims = newIsScalableDim;
return *this;
}

Expand All @@ -345,25 +329,16 @@ class VectorType::Builder {
/// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) {
assert(pos < shape.size() && "overflow");
if (storage.empty())
storage.append(shape.begin(), shape.end());
if (storageScalableDims.empty())
storageScalableDims.append(scalableDims.begin(), scalableDims.end());
storage.erase(storage.begin() + pos);
storageScalableDims.erase(storageScalableDims.begin() + pos);
shape = {storage.data(), storage.size()};
scalableDims =
ArrayRef<bool>(storageScalableDims.data(), storageScalableDims.size());
shape.erase(pos);
if (!scalableDims.empty())
scalableDims.erase(pos);
return *this;
}

/// Set a dim in shape @pos to val.
Builder &setDim(unsigned pos, int64_t val) {
if (storage.empty())
storage.append(shape.begin(), shape.end());
assert(pos < storage.size() && "overflow");
storage[pos] = val;
shape = {storage.data(), storage.size()};
assert(pos < shape.size() && "overflow");
shape.set(pos, val);
return *this;
}

Expand All @@ -372,13 +347,9 @@ class VectorType::Builder {
}

private:
ArrayRef<int64_t> shape;
// Owning shape data for copy-on-write operations.
SmallVector<int64_t> storage;
Type elementType;
ArrayRef<bool> scalableDims;
// Owning scalableDims data for copy-on-write operations.
SmallVector<bool> storageScalableDims;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
};

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
Expand Down
82 changes: 82 additions & 0 deletions mlir/include/mlir/Support/ADTExtras.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
//===- ADTExtras.h - Extra ADTs for use in MLIR -----------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_SUPPORT_ADTEXTRAS_H
#define MLIR_SUPPORT_ADTEXTRAS_H

#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

#include <cassert>

namespace mlir {

//===----------------------------------------------------------------------===//
// CopyOnWriteArrayRef<T>
//===----------------------------------------------------------------------===//

/// A wrapper around an ArrayRef<T> that copies to a SmallVector<T> on
/// modification. This is for use in the mlir::<Type>::Builders.
///
/// The wrapper starts out referencing caller-provided data (`nonOwning`). The
/// elements are copied into `owningStorage` the first time a modification
/// cannot be expressed as a narrower view of the referenced data (i.e. on
/// `insert`, `set`, or `erase` of an interior element).
///
/// Invariant: at most one of `nonOwning` and `owningStorage` is non-empty, and
/// `nonOwning` never points into `owningStorage`. Because of this, the
/// implicitly generated copy operations never produce a view into another
/// object's owned storage.
template <typename T>
class CopyOnWriteArrayRef {
public:
  /// Wraps `array` without copying it. The referenced elements are not owned
  /// by this object. Implicit conversion from ArrayRef<T> is allowed.
  CopyOnWriteArrayRef(ArrayRef<T> array) : nonOwning(array) {}

  /// Rebinds this object to `array` and discards any previously owned copy.
  CopyOnWriteArrayRef &operator=(ArrayRef<T> array) {
    nonOwning = array;
    owningStorage = {};
    return *this;
  }

  /// Inserts `value` before position `index` (`index == size()` appends).
  /// Always operates on the owned copy, creating it first if necessary.
  void insert(size_t index, T value) {
    assert(index <= size() && "insert index out of bounds");
    SmallVector<T> &vector = ensureCopy();
    vector.insert(vector.begin() + index, value);
  }

  /// Removes the element at `index`; `index` must be less than `size()`.
  void erase(size_t index) {
    assert(index < size() && "erase index out of bounds");
    if (isNonOwning()) {
      // Note: A copy can be avoided when just dropping the front/back dims;
      // narrowing the view is enough.
      if (index == 0) {
        nonOwning = nonOwning.drop_front();
        return;
      }
      if (index == nonOwning.size() - 1) {
        nonOwning = nonOwning.drop_back();
        return;
      }
    }
    SmallVector<T> &vector = ensureCopy();
    vector.erase(vector.begin() + index);
  }

  /// Overwrites the element at `index` with `value`; `index` must be less
  /// than `size()`. Always operates on the owned copy.
  void set(size_t index, T value) {
    assert(index < size() && "set index out of bounds");
    ensureCopy()[index] = value;
  }

  /// Number of elements currently visible through this object.
  size_t size() const { return ArrayRef<T>(*this).size(); }

  /// True when no elements are visible through this object.
  bool empty() const { return ArrayRef<T>(*this).empty(); }

  /// Returns a view of the current contents: the referenced data while no
  /// copy has been made, otherwise the owned copy. In the latter case the
  /// view points into this object's storage, so it should not be used after
  /// this object is modified or destroyed.
  operator ArrayRef<T>() const {
    return nonOwning.empty() ? ArrayRef<T>(owningStorage) : nonOwning;
  }

private:
  /// True while the contents still live in caller-provided data.
  bool isNonOwning() const { return !nonOwning.empty(); }

  /// Returns the owned storage, first copying the referenced elements into it
  /// if that has not happened yet.
  SmallVector<T> &ensureCopy() {
    // Empty non-owning storage signals the array has been copied to the owning
    // storage (or both are empty). Note: `nonOwning` should never reference
    // `owningStorage`. This can lead to dangling references if the
    // CopyOnWriteArrayRef<T> is copied.
    if (isNonOwning()) {
      owningStorage = SmallVector<T>(nonOwning);
      nonOwning = {};
    }
    return owningStorage;
  }

  /// View of caller-provided data; empty once a copy has been made.
  ArrayRef<T> nonOwning;
  /// Owned copy of the elements; empty until the first copying modification.
  SmallVector<T> owningStorage;
};

} // namespace mlir

#endif
95 changes: 95 additions & 0 deletions mlir/unittests/IR/ShapedTypeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,99 @@ TEST(ShapedTypeTest, CloneVector) {
VectorType::get(vectorNewShape, vectorNewType));
}

TEST(ShapedTypeTest, VectorTypeBuilder) {
  MLIRContext ctx;
  Type elemTy = FloatType::getF32(&ctx);

  // Reference type every scenario below starts from: a 5-D vector whose
  // dims 0 and 2 are scalable.
  SmallVector<int64_t> baseShape{2, 4, 8, 9, 1};
  SmallVector<bool> baseScalable{true, false, true, false, false};
  VectorType baseType = VectorType::get(baseShape, elemTy, baseScalable);

  {
    // Dropping the leading dim twice keeps the element type and removes the
    // first two entries of both the shape and the scalable-dims list.
    VectorType withoutLeadingDims =
        VectorType::Builder(baseType).dropDim(0).dropDim(0);
    ASSERT_EQ(baseType.getElementType(), withoutLeadingDims.getElementType());
    ASSERT_EQ(baseType.getShape().drop_front(2),
              withoutLeadingDims.getShape());
    ASSERT_EQ(baseType.getScalableDims().drop_front(2),
              withoutLeadingDims.getScalableDims());
  }

  {
    // Overwriting dims 0 and 3 changes only those two shape entries.
    SmallVector<int64_t> expectedShape{10, 4, 8, 12, 1};
    VectorType withUpdatedDims =
        VectorType::Builder(baseType).setDim(0, 10).setDim(3, 12);
    ASSERT_EQ(withUpdatedDims.getShape(), ArrayRef<int64_t>(expectedShape));
    ASSERT_EQ(baseType.getElementType(), withUpdatedDims.getElementType());
    ASSERT_EQ(baseType.getScalableDims(), withUpdatedDims.getScalableDims());
  }

  {
    // Regression test for the bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // A temporary builder is constructed and modified, then copied into
    // `builder`. This used to lead to a use-after-free. Running under
    // sanitizers will catch any issues.
    VectorType::Builder builder = VectorType::Builder(baseType).setDim(0, 16);
    auto rebuiltType = VectorType(builder);
    ASSERT_EQ(rebuiltType.getDimSize(0), 16);
  }

  {
    // Builder made from scratch (without scalable dims) -- this used to lead
    // to a use-after-free, see:
    // https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    SmallVector<int64_t> scratchShape{1, 2, 3, 4};
    VectorType::Builder builder(scratchShape, elemTy);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(scratchShape));
  }

  {
    // Replacing the shape (without scalable dims) -- this used to lead to a
    // use-after-free, see: https://github.com/llvm/llvm-project/pull/68969.
    // Running under sanitizers will catch any issues.
    SmallVector<int64_t> replacementShape{2, 2};
    VectorType::Builder builder(baseType);
    builder.setShape(replacementShape);
    ASSERT_EQ(VectorType(builder).getShape(), ArrayRef(replacementShape));
  }
}

TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
  MLIRContext ctx;
  Type elemTy = FloatType::getF32(&ctx);

  // Reference type every scenario below starts from.
  SmallVector<int64_t> baseShape{2, 4, 8, 16, 32};
  RankedTensorType baseType = RankedTensorType::get(baseShape, elemTy);

  {
    // Three successive drops: {2,4,8,16,32} -> {4,8,16,32} -> {4,16,32}
    // -> {16,32}. The element type is unaffected.
    SmallVector<int64_t> expectedShape{16, 32};
    RankedTensorType trailingDimsOnly =
        RankedTensorType::Builder(baseType).dropDim(0).dropDim(1).dropDim(0);
    ASSERT_EQ(baseType.getElementType(), trailingDimsOnly.getElementType());
    ASSERT_EQ(trailingDimsOnly.getShape(), ArrayRef<int64_t>(expectedShape));
  }

  {
    // Insert 7 at position 2, then 9 at position 3 of the resulting shape.
    SmallVector<int64_t> expectedShape{2, 4, 7, 9, 8, 16, 32};
    RankedTensorType withExtraDims =
        RankedTensorType::Builder(baseType).insertDim(7, 2).insertDim(9, 3);
    ASSERT_EQ(baseType.getElementType(), withExtraDims.getElementType());
    ASSERT_EQ(withExtraDims.getShape(), ArrayRef<int64_t>(expectedShape));
  }

  {
    // Regression test for the bug from:
    // https://github.com/llvm/llvm-project/commit/b44b3494f60296db6aca38a14cab061d9b747a0a
    // A temporary builder is constructed and modified, then copied into
    // `builder`. This used to lead to a use-after-free. Running under
    // sanitizers will catch any issues.
    RankedTensorType::Builder builder =
        RankedTensorType::Builder(baseType).dropDim(0);
    auto rebuiltType = RankedTensorType(builder);
    ASSERT_EQ(baseType.getShape().drop_front(), rebuiltType.getShape());
  }
}

} // namespace