Skip to content

Commit ab6e63a

Browse files
authored
[mlir] Make single value ValueRanges memory safer (#121996)
A very common mistake users (and yours truly) make when using `ValueRange`s is assigning a temporary `Value` to it. Example: ```cpp ValueRange values = op.getOperand(); apiThatUsesValueRange(values); ``` The issue is caused by the implicit `const Value&` constructor: As per C++ rules a const reference can be constructed from a temporary and the address of it taken. After the statement, the temporary goes out of scope and `stack-use-after-free` error occurs. This PR fixes that issue by making `ValueRange` capable of owning a single `Value` instance for that case specifically. While technically a departure from the other owner types that are non-owning, I'd argue that this behavior is more intuitive for the majority of users that usually don't need to care about the lifetime of `Value` instances. `TypeRange` has similarly been adopted to accept a single `Type` instance to implement `getTypes`.
1 parent 1e53f95 commit ab6e63a

File tree

5 files changed

+66
-16
lines changed

5 files changed

+66
-16
lines changed

mlir/include/mlir/IR/TypeRange.h

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ namespace mlir {
2929
/// a SmallVector/std::vector. This class should be used in places that are not
3030
/// suitable for a more derived type (e.g. ArrayRef) or a template range
3131
/// parameter.
32-
class TypeRange : public llvm::detail::indexed_accessor_range_base<
33-
TypeRange,
34-
llvm::PointerUnion<const Value *, const Type *,
35-
OpOperand *, detail::OpResultImpl *>,
36-
Type, Type, Type> {
32+
class TypeRange
33+
: public llvm::detail::indexed_accessor_range_base<
34+
TypeRange,
35+
llvm::PointerUnion<const Value *, const Type *, OpOperand *,
36+
detail::OpResultImpl *, Type>,
37+
Type, Type, Type> {
3738
public:
3839
using RangeBaseT::RangeBaseT;
3940
TypeRange(ArrayRef<Type> types = std::nullopt);
@@ -44,8 +45,11 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
4445
TypeRange(ValueTypeRange<ValueRangeT> values)
4546
: TypeRange(ValueRange(ValueRangeT(values.begin().getCurrent(),
4647
values.end().getCurrent()))) {}
47-
template <typename Arg, typename = std::enable_if_t<std::is_constructible<
48-
ArrayRef<Type>, Arg>::value>>
48+
49+
TypeRange(Type type) : TypeRange(type, /*count=*/1) {}
50+
template <typename Arg, typename = std::enable_if_t<
51+
std::is_constructible_v<ArrayRef<Type>, Arg> &&
52+
!std::is_constructible_v<Type, Arg>>>
4953
TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
5054
TypeRange(std::initializer_list<Type> types)
5155
: TypeRange(ArrayRef<Type>(types)) {}
@@ -56,8 +60,9 @@ class TypeRange : public llvm::detail::indexed_accessor_range_base<
5660
/// * A pointer to the first element of an array of types.
5761
/// * A pointer to the first element of an array of operands.
5862
/// * A pointer to the first element of an array of results.
63+
/// * A single 'Type' instance.
5964
using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *,
60-
detail::OpResultImpl *>;
65+
detail::OpResultImpl *, Type>;
6166

6267
/// See `llvm::detail::indexed_accessor_range_base` for details.
6368
static OwnerT offset_base(OwnerT object, ptrdiff_t index);

mlir/include/mlir/IR/ValueRange.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -374,16 +374,16 @@ class ResultRange::UseIterator final
374374
/// SmallVector/std::vector. This class should be used in places that are not
375375
/// suitable for a more derived type (e.g. ArrayRef) or a template range
376376
/// parameter.
377-
class ValueRange final
378-
: public llvm::detail::indexed_accessor_range_base<
379-
ValueRange,
380-
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>,
381-
Value, Value, Value> {
377+
class ValueRange final : public llvm::detail::indexed_accessor_range_base<
378+
ValueRange,
379+
PointerUnion<const Value *, OpOperand *,
380+
detail::OpResultImpl *, Value>,
381+
Value, Value, Value> {
382382
public:
383383
/// The type representing the owner of a ValueRange. This is either a list of
384-
/// values, operands, or results.
384+
/// values, operands, or results or a single value.
385385
using OwnerT =
386-
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *>;
386+
PointerUnion<const Value *, OpOperand *, detail::OpResultImpl *, Value>;
387387

388388
using RangeBaseT::RangeBaseT;
389389

@@ -392,7 +392,7 @@ class ValueRange final
392392
std::is_constructible<ArrayRef<Value>, Arg>::value &&
393393
!std::is_convertible<Arg, Value>::value>>
394394
ValueRange(Arg &&arg) : ValueRange(ArrayRef<Value>(std::forward<Arg>(arg))) {}
395-
ValueRange(const Value &value) : ValueRange(&value, /*count=*/1) {}
395+
ValueRange(Value value) : ValueRange(value, /*count=*/1) {}
396396
ValueRange(const std::initializer_list<Value> &values)
397397
: ValueRange(ArrayRef<Value>(values)) {}
398398
ValueRange(iterator_range<OperandRange::iterator> values)

mlir/lib/IR/OperationSupport.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,15 @@ ValueRange::ValueRange(ResultRange values)
653653
/// See `llvm::detail::indexed_accessor_range_base` for details.
654654
ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
655655
ptrdiff_t index) {
656+
if (llvm::isa_and_nonnull<Value>(owner)) {
657+
// Prevent out-of-bounds indexing for single values.
658+
// Note that we do allow an index of 1 as is required by 'slice'ing that
659+
// returns an empty range. This also matches the usual rules of C++ of being
660+
// allowed to index past the last element of an array.
661+
assert(index <= 1 && "out-of-bound offset into single-value 'ValueRange'");
662+
// Return nullptr to quickly cause segmentation faults on misuse.
663+
return index == 0 ? owner : nullptr;
664+
}
656665
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
657666
return {value + index};
658667
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
@@ -661,6 +670,10 @@ ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner,
661670
}
662671
/// See `llvm::detail::indexed_accessor_range_base` for details.
663672
Value ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) {
673+
if (auto value = llvm::dyn_cast_if_present<Value>(owner)) {
674+
assert(index == 0 && "cannot offset into single-value 'ValueRange'");
675+
return value;
676+
}
664677
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(owner))
665678
return value[index];
666679
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))

mlir/lib/IR/TypeRange.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,23 @@ TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
3131
this->base = result;
3232
else if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(owner))
3333
this->base = operand;
34+
else if (auto value = llvm::dyn_cast_if_present<Value>(owner))
35+
this->base = value.getType();
3436
else
3537
this->base = cast<const Value *>(owner);
3638
}
3739

3840
/// See `llvm::detail::indexed_accessor_range_base` for details.
3941
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
42+
if (llvm::isa_and_nonnull<Type>(object)) {
43+
// Prevent out-of-bounds indexing for single values.
44+
// Note that we do allow an index of 1 as is required by 'slice'ing that
45+
// returns an empty range. This also matches the usual rules of C++ of being
46+
// allowed to index past the last element of an array.
47+
assert(index <= 1 && "out-of-bound offset into single-value 'ValueRange'");
48+
// Return nullptr to quickly cause segmentation faults on misuse.
49+
return index == 0 ? object : nullptr;
50+
}
4051
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
4152
return {value + index};
4253
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))
@@ -48,6 +59,10 @@ TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
4859

4960
/// See `llvm::detail::indexed_accessor_range_base` for details.
5061
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
62+
if (auto type = llvm::dyn_cast_if_present<Type>(object)) {
63+
assert(index == 0 && "cannot offset into single-value 'TypeRange'");
64+
return type;
65+
}
5166
if (const auto *value = llvm::dyn_cast_if_present<const Value *>(object))
5267
return (value + index)->getType();
5368
if (auto *operand = llvm::dyn_cast_if_present<OpOperand *>(object))

mlir/unittests/IR/OperationSupportTest.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,4 +313,21 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) {
313313
op2->destroy();
314314
}
315315

316+
TEST(ValueRangeTest, ValueConstructable) {
317+
MLIRContext context;
318+
Builder builder(&context);
319+
320+
Operation *useOp =
321+
createOp(&context, /*operands=*/std::nullopt, builder.getIntegerType(16));
322+
// Valid construction despite a temporary 'OpResult'.
323+
ValueRange operands = useOp->getResult(0);
324+
325+
useOp->setOperands(operands);
326+
EXPECT_EQ(useOp->getNumOperands(), 1u);
327+
EXPECT_EQ(useOp->getOperand(0), useOp->getResult(0));
328+
329+
useOp->dropAllUses();
330+
useOp->destroy();
331+
}
332+
316333
} // namespace

0 commit comments

Comments
 (0)