Skip to content

Commit 94a4ca4

Browse files
committed
[mlir] Add a TypeRange class that functions similar to ValueRange.
Summary: This class wraps around the various different ways to construct a range of Type, without forcing the materialization of that range into a contiguous vector. Differential Revision: https://reviews.llvm.org/D74646
1 parent 2d146aa commit 94a4ca4

File tree

4 files changed

+103
-9
lines changed

4 files changed

+103
-9
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,9 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
188188
return p << (value ? StringRef("true") : "false");
189189
}
190190

191-
template <typename IteratorT>
192-
inline OpAsmPrinter &
193-
operator<<(OpAsmPrinter &p,
194-
const iterator_range<ValueTypeIterator<IteratorT>> &types) {
191+
template <typename ValueRangeT>
192+
inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
193+
const ValueTypeRange<ValueRangeT> &types) {
195194
interleaveComma(types, p);
196195
return p;
197196
}

mlir/include/mlir/IR/Operation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class Operation final
232232

233233
// Support operand type iteration.
234234
using operand_type_iterator = operand_range::type_iterator;
235-
using operand_type_range = iterator_range<operand_type_iterator>;
235+
using operand_type_range = operand_range::type_range;
236236
operand_type_iterator operand_type_begin() { return operand_begin(); }
237237
operand_type_iterator operand_type_end() { return operand_end(); }
238238
operand_type_range getOperandTypes() { return getOperands().getTypes(); }
@@ -260,7 +260,7 @@ class Operation final
260260

261261
/// Support result type iteration.
262262
using result_type_iterator = result_range::type_iterator;
263-
using result_type_range = ArrayRef<Type>;
263+
using result_type_range = result_range::type_range;
264264
result_type_iterator result_type_begin() { return getResultTypes().begin(); }
265265
result_type_iterator result_type_end() { return getResultTypes().end(); }
266266
result_type_range getResultTypes();

mlir/include/mlir/IR/OperationSupport.h

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,17 @@ struct OperationState;
3232
class OpAsmParser;
3333
class OpAsmParserResult;
3434
class OpAsmPrinter;
35+
class OperandRange;
3536
class OpFoldResult;
3637
class ParseResult;
3738
class Pattern;
3839
class Region;
40+
class ResultRange;
3941
class RewritePattern;
4042
class Type;
4143
class Value;
4244
class ValueRange;
45+
template <typename ValueRangeT> class ValueTypeRange;
4346

4447
/// This is an adaptor from a list of values to named operands of OpTy. In a
4548
/// generic operation context, e.g., in dialect conversions, an ordered array of
@@ -535,6 +538,46 @@ class OpPrintingFlags {
535538
// Operation Value-Iterators
536539
//===----------------------------------------------------------------------===//
537540

541+
//===----------------------------------------------------------------------===//
542+
// TypeRange
543+
544+
/// This class provides an abstraction over the various different ranges of
545+
/// value types. In many cases, this prevents the need to explicitly materialize
546+
/// a SmallVector/std::vector. This class should be used in places that are not
547+
/// suitable for a more derived type (e.g. ArrayRef) or a template range
548+
/// parameter.
549+
class TypeRange
550+
: public detail::indexed_accessor_range_base<
551+
TypeRange,
552+
llvm::PointerUnion<const Value *, const Type *, OpOperand *>, Type,
553+
Type, Type> {
554+
public:
555+
using RangeBaseT::RangeBaseT;
556+
TypeRange(ArrayRef<Type> types = llvm::None);
557+
explicit TypeRange(OperandRange values);
558+
explicit TypeRange(ResultRange values);
559+
explicit TypeRange(ValueRange values);
560+
template <typename ValueRangeT>
561+
TypeRange(ValueTypeRange<ValueRangeT> values)
562+
: TypeRange(ValueRangeT(values.begin().getCurrent(),
563+
values.end().getCurrent())) {}
564+
565+
private:
566+
/// The owner of the range is either:
567+
/// * A pointer to the first element of an array of values.
568+
/// * A pointer to the first element of an array of types.
569+
/// * A pointer to the first element of an array of operands.
570+
using OwnerT = llvm::PointerUnion<const Value *, const Type *, OpOperand *>;
571+
572+
/// See `detail::indexed_accessor_range_base` for details.
573+
static OwnerT offset_base(OwnerT object, ptrdiff_t index);
574+
/// See `detail::indexed_accessor_range_base` for details.
575+
static Type dereference_iterator(OwnerT object, ptrdiff_t index);
576+
577+
/// Allow access to `offset_base` and `dereference_iterator`.
578+
friend RangeBaseT;
579+
};
580+
538581
//===----------------------------------------------------------------------===//
539582
// ValueTypeRange
540583

@@ -555,6 +598,18 @@ class ValueTypeIterator final
555598
: llvm::mapped_iterator<ValueIteratorT, Type (*)(Value)>(it, &unwrap) {}
556599
};
557600

601+
/// This class implements iteration on the types of a given range of values.
602+
template <typename ValueRangeT>
603+
class ValueTypeRange final
604+
: public llvm::iterator_range<
605+
ValueTypeIterator<typename ValueRangeT::iterator>> {
606+
public:
607+
using llvm::iterator_range<
608+
ValueTypeIterator<typename ValueRangeT::iterator>>::iterator_range;
609+
template <typename Container>
610+
ValueTypeRange(Container &&c) : ValueTypeRange(c.begin(), c.end()) {}
611+
};
612+
558613
//===----------------------------------------------------------------------===//
559614
// OperandRange
560615

@@ -568,7 +623,8 @@ class OperandRange final
568623

569624
/// Returns the types of the values within this range.
570625
using type_iterator = ValueTypeIterator<iterator>;
571-
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
626+
using type_range = ValueTypeRange<OperandRange>;
627+
type_range getTypes() const { return {begin(), end()}; }
572628

573629
private:
574630
/// See `detail::indexed_accessor_range_base` for details.
@@ -598,7 +654,8 @@ class ResultRange final
598654

599655
/// Returns the types of the values within this range.
600656
using type_iterator = ArrayRef<Type>::iterator;
601-
ArrayRef<Type> getTypes() const;
657+
using type_range = ArrayRef<Type>;
658+
type_range getTypes() const;
602659

603660
private:
604661
/// See `indexed_accessor_range` for details.
@@ -666,7 +723,8 @@ class ValueRange final
666723

667724
/// Returns the types of the values within this range.
668725
using type_iterator = ValueTypeIterator<iterator>;
669-
iterator_range<type_iterator> getTypes() const { return {begin(), end()}; }
726+
using type_range = ValueTypeRange<ValueRange>;
727+
type_range getTypes() const { return {begin(), end()}; }
670728

671729
private:
672730
using OwnerT = detail::ValueRangeOwner;

mlir/lib/IR/OperationSupport.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,43 @@ void detail::OperandStorage::grow(ResizableStorage &resizeUtil,
140140
// Operation Value-Iterators
141141
//===----------------------------------------------------------------------===//
142142

143+
//===----------------------------------------------------------------------===//
144+
// TypeRange
145+
146+
TypeRange::TypeRange(ArrayRef<Type> types)
147+
: TypeRange(types.data(), types.size()) {}
148+
TypeRange::TypeRange(OperandRange values)
149+
: TypeRange(values.begin().getBase(), values.size()) {}
150+
TypeRange::TypeRange(ResultRange values)
151+
: TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(),
152+
values.size())) {}
153+
TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
154+
detail::ValueRangeOwner owner = values.begin().getBase();
155+
if (auto *op = reinterpret_cast<Operation *>(owner.ptr.dyn_cast<void *>()))
156+
this->base = &op->getResultTypes()[owner.startIndex];
157+
else if (auto *operand = owner.ptr.dyn_cast<OpOperand *>())
158+
this->base = operand;
159+
else
160+
this->base = owner.ptr.get<const Value *>();
161+
}
162+
163+
/// See `detail::indexed_accessor_range_base` for details.
164+
TypeRange::OwnerT TypeRange::offset_base(OwnerT object, ptrdiff_t index) {
165+
if (auto *value = object.dyn_cast<const Value *>())
166+
return {value + index};
167+
if (auto *operand = object.dyn_cast<OpOperand *>())
168+
return {operand + index};
169+
return {object.dyn_cast<const Type *>() + index};
170+
}
171+
/// See `detail::indexed_accessor_range_base` for details.
172+
Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
173+
if (auto *value = object.dyn_cast<const Value *>())
174+
return (value + index)->getType();
175+
if (auto *operand = object.dyn_cast<OpOperand *>())
176+
return (operand + index)->get().getType();
177+
return object.dyn_cast<const Type *>()[index];
178+
}
179+
143180
//===----------------------------------------------------------------------===//
144181
// OperandRange
145182

0 commit comments

Comments
 (0)