Skip to content

Commit 7be6a40

Browse files
River707tensorflower-gardener
authored andcommitted
Add new indexed_accessor_range_base and indexed_accessor_range classes that simplify defining index-able ranges.
Many ranges want similar functionality from a range type(e.g. slice/drop_front/operator[]/etc.), so these classes provide a generic implementation that may be used by many different types of ranges. This removes some code duplication, and also empowers many of the existing range types in MLIR(e.g. result type ranges, operand ranges, ElementsAttr ranges, etc.). This change only updates RegionRange and ValueRange, more ranges will be updated in followup commits. PiperOrigin-RevId: 284615679
1 parent 56da744 commit 7be6a40

File tree

8 files changed

+206
-141
lines changed

8 files changed

+206
-141
lines changed

mlir/include/mlir/IR/Attributes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -641,12 +641,12 @@ class DenseElementIndexedIteratorImpl
641641
/// Return the current index for this iterator, adjusted for the case of a
642642
/// splat.
643643
ptrdiff_t getDataIndex() const {
644-
bool isSplat = this->object.getInt();
644+
bool isSplat = this->base.getInt();
645645
return isSplat ? 0 : this->index;
646646
}
647647

648-
/// Return the data object pointer.
649-
const char *getData() const { return this->object.getPointer(); }
648+
/// Return the data base pointer.
649+
const char *getData() const { return this->base.getPointer(); }
650650
};
651651
} // namespace detail
652652

mlir/include/mlir/IR/Block.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,9 @@ class SuccessorIterator final
460460
Block *>(object, index) {}
461461

462462
SuccessorIterator(const SuccessorIterator &other)
463-
: SuccessorIterator(other.object, other.index) {}
463+
: SuccessorIterator(other.base, other.index) {}
464464

465-
Block *operator*() const { return this->object->getSuccessor(this->index); }
465+
Block *operator*() const { return this->base->getSuccessor(this->index); }
466466

467467
/// Get the successor number in the terminator.
468468
unsigned getSuccessorIndex() const { return this->index; }

mlir/include/mlir/IR/Operation.h

Lines changed: 20 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ class OperandIterator final
668668
: indexed_accessor_iterator<OperandIterator, Operation *, Value *,
669669
Value *, Value *>(object, index) {}
670670

671-
Value *operator*() const { return this->object->getOperand(this->index); }
671+
Value *operator*() const { return this->base->getOperand(this->index); }
672672
};
673673

674674
/// This class implements the operand type iterators for the Operation
@@ -721,11 +721,11 @@ class ResultIterator final
721721
Value *, Value *> {
722722
public:
723723
/// Initializes the result iterator to the specified index.
724-
ResultIterator(Operation *object, unsigned index)
724+
ResultIterator(Operation *base, unsigned index)
725725
: indexed_accessor_iterator<ResultIterator, Operation *, Value *, Value *,
726-
Value *>(object, index) {}
726+
Value *>(base, index) {}
727727

728-
Value *operator*() const { return this->object->getResult(this->index); }
728+
Value *operator*() const { return this->base->getResult(this->index); }
729729
};
730730

731731
/// This class implements the result type iterators for the Operation
@@ -799,15 +799,19 @@ inline auto Operation::getResultTypes() -> result_type_range {
799799
/// SmallVector/std::vector. This class should be used in places that are not
800800
/// suitable for a more derived type (e.g. ArrayRef) or a template range
801801
/// parameter.
802-
class ValueRange {
802+
class ValueRange
803+
: public detail::indexed_accessor_range_base<
804+
ValueRange,
805+
llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>, Value *,
806+
Value *, Value *> {
803807
/// The type representing the owner of this range. This is either a list of
804808
/// values, operands, or results.
805809
using OwnerT = llvm::PointerUnion<Value *const *, OpOperand *, OpResult *>;
806810

807811
public:
808-
ValueRange(const ValueRange &) = default;
809-
ValueRange(ValueRange &&) = default;
810-
ValueRange &operator=(const ValueRange &) = default;
812+
using detail::indexed_accessor_range_base<
813+
ValueRange, OwnerT, Value *, Value *,
814+
Value *>::indexed_accessor_range_base;
811815

812816
template <typename Arg,
813817
typename = typename std::enable_if_t<
@@ -822,46 +826,15 @@ class ValueRange {
822826
ValueRange(iterator_range<OperandIterator> values);
823827
ValueRange(iterator_range<ResultIterator> values);
824828

825-
/// An iterator element of this range.
826-
class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Value *,
827-
Value *, Value *> {
828-
public:
829-
Value *operator*() const;
830-
831-
private:
832-
Iterator(OwnerT owner, unsigned curIndex);
833-
834-
/// Allow access to the constructor.
835-
friend ValueRange;
836-
};
837-
838-
Iterator begin() const { return Iterator(owner, 0); }
839-
Iterator end() const { return Iterator(owner, count); }
840-
Value *operator[](unsigned index) const {
841-
assert(index < size() && "invalid index for value range");
842-
return *std::next(begin(), index);
843-
}
844-
845-
/// Return the size of this range.
846-
size_t size() const { return count; }
847-
848-
/// Return if the range is empty.
849-
bool empty() const { return size() == 0; }
850-
851-
/// Drop the first N elements, and keep M elements.
852-
ValueRange slice(unsigned n, unsigned m) const;
853-
/// Drop the first n elements.
854-
ValueRange drop_front(unsigned n = 1) const;
855-
/// Drop the last n elements.
856-
ValueRange drop_back(unsigned n = 1) const;
857-
858829
private:
859-
ValueRange(OwnerT owner, unsigned count) : owner(owner), count(count) {}
860-
861-
/// The object that owns the provided range of values.
862-
OwnerT owner;
863-
/// The size from the owning range.
864-
unsigned count;
830+
/// See `detail::indexed_accessor_range_base` for details.
831+
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
832+
/// See `detail::indexed_accessor_range_base` for details.
833+
static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
834+
835+
/// Allow access to `offset_base` and `dereference_iterator`.
836+
friend detail::indexed_accessor_range_base<ValueRange, OwnerT, Value *,
837+
Value *, Value *>;
865838
};
866839

867840
} // end namespace mlir

mlir/include/mlir/IR/Region.h

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -165,14 +165,19 @@ class Region {
165165
/// SmallVector/std::vector. This class should be used in places that are not
166166
/// suitable for a more derived type (e.g. ArrayRef) or a template range
167167
/// parameter.
168-
class RegionRange {
168+
class RegionRange
169+
: public detail::indexed_accessor_range_base<
170+
RegionRange,
171+
llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>,
172+
Region *, Region *, Region *> {
169173
/// The type representing the owner of this range. This is either a list of
170174
/// values, operands, or results.
171175
using OwnerT = llvm::PointerUnion<Region *, const std::unique_ptr<Region> *>;
172176

173177
public:
174-
RegionRange(const RegionRange &) = default;
175-
RegionRange(RegionRange &&) = default;
178+
using detail::indexed_accessor_range_base<
179+
RegionRange, OwnerT, Region *, Region *,
180+
Region *>::indexed_accessor_range_base;
176181

177182
RegionRange(MutableArrayRef<Region> regions = llvm::None);
178183

@@ -184,33 +189,15 @@ class RegionRange {
184189
}
185190
RegionRange(ArrayRef<std::unique_ptr<Region>> regions);
186191

187-
/// An iterator element of this range.
188-
class Iterator : public indexed_accessor_iterator<Iterator, OwnerT, Region *,
189-
Region *, Region *> {
190-
public:
191-
Region *operator*() const;
192-
193-
private:
194-
Iterator(OwnerT owner, unsigned curIndex);
195-
/// Allow access to the constructor.
196-
friend RegionRange;
197-
};
198-
Iterator begin() const { return Iterator(owner, 0); }
199-
Iterator end() const { return Iterator(owner, count); }
200-
Region *operator[](unsigned index) const {
201-
assert(index < size() && "invalid index for region range");
202-
return *std::next(begin(), index);
203-
}
204-
/// Return the size of this range.
205-
size_t size() const { return count; }
206-
/// Return if the range is empty.
207-
bool empty() const { return size() == 0; }
208-
209192
private:
210-
/// The object that owns the provided range of regions.
211-
OwnerT owner;
212-
/// The size from the owning range.
213-
unsigned count;
193+
/// See `detail::indexed_accessor_range_base` for details.
194+
static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index);
195+
/// See `detail::indexed_accessor_range_base` for details.
196+
static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index);
197+
198+
/// Allow access to `offset_base` and `dereference_iterator`.
199+
friend detail::indexed_accessor_range_base<RegionRange, OwnerT, Region *,
200+
Region *, Region *>;
214201
};
215202

216203
} // end namespace mlir

mlir/include/mlir/Support/STLExtras.h

Lines changed: 129 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,24 +147,24 @@ using is_invocable = is_detected<detail::is_invocable, Callable, Args...>;
147147
// Extra additions to <iterator>
148148
//===----------------------------------------------------------------------===//
149149

150-
/// A utility class used to implement an iterator that contains some object and
151-
/// an index. The iterator moves the index but keeps the object constant.
152-
template <typename DerivedT, typename ObjectType, typename T,
150+
/// A utility class used to implement an iterator that contains some base object
151+
/// and an index. The iterator moves the index but keeps the base constant.
152+
template <typename DerivedT, typename BaseT, typename T,
153153
typename PointerT = T *, typename ReferenceT = T &>
154154
class indexed_accessor_iterator
155155
: public llvm::iterator_facade_base<DerivedT,
156156
std::random_access_iterator_tag, T,
157157
std::ptrdiff_t, PointerT, ReferenceT> {
158158
public:
159159
ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const {
160-
assert(object == rhs.object && "incompatible iterators");
160+
assert(base == rhs.base && "incompatible iterators");
161161
return index - rhs.index;
162162
}
163163
bool operator==(const indexed_accessor_iterator &rhs) const {
164-
return object == rhs.object && index == rhs.index;
164+
return base == rhs.base && index == rhs.index;
165165
}
166166
bool operator<(const indexed_accessor_iterator &rhs) const {
167-
assert(object == rhs.object && "incompatible iterators");
167+
assert(base == rhs.base && "incompatible iterators");
168168
return index < rhs.index;
169169
}
170170

@@ -180,16 +180,134 @@ class indexed_accessor_iterator
180180
/// Returns the current index of the iterator.
181181
ptrdiff_t getIndex() const { return index; }
182182

183-
/// Returns the current object of the iterator.
184-
const ObjectType &getObject() const { return object; }
183+
/// Returns the current base of the iterator.
184+
const BaseT &getBase() const { return base; }
185185

186186
protected:
187-
indexed_accessor_iterator(ObjectType object, ptrdiff_t index)
188-
: object(object), index(index) {}
189-
ObjectType object;
187+
indexed_accessor_iterator(BaseT base, ptrdiff_t index)
188+
: base(base), index(index) {}
189+
BaseT base;
190190
ptrdiff_t index;
191191
};
192192

193+
namespace detail {
194+
/// The class represents the base of a range of indexed_accessor_iterators. It
195+
/// provides support for many different range functionalities, e.g.
196+
/// drop_front/slice/etc.. Derived range classes must implement the following
197+
/// static methods:
198+
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
199+
/// - Derefence an iterator pointing to the base object at the given index.
200+
/// * BaseT offset_base(const BaseT &base, ptrdiff_t index)
201+
/// - Return a new base that is offset from the provide base by 'index'
202+
/// elements.
203+
template <typename DerivedT, typename BaseT, typename T,
204+
typename PointerT = T *, typename ReferenceT = T &>
205+
class indexed_accessor_range_base {
206+
public:
207+
/// An iterator element of this range.
208+
class iterator : public indexed_accessor_iterator<iterator, BaseT, T,
209+
PointerT, ReferenceT> {
210+
public:
211+
// Index into this iterator, invoking a static method on the derived type.
212+
ReferenceT operator*() const {
213+
return DerivedT::dereference_iterator(this->getBase(), this->getIndex());
214+
}
215+
216+
private:
217+
iterator(BaseT owner, ptrdiff_t curIndex)
218+
: indexed_accessor_iterator<iterator, BaseT, T, PointerT, ReferenceT>(
219+
owner, curIndex) {}
220+
221+
/// Allow access to the constructor.
222+
friend indexed_accessor_range_base<DerivedT, BaseT, T, PointerT,
223+
ReferenceT>;
224+
};
225+
226+
iterator begin() const { return iterator(base, 0); }
227+
iterator end() const { return iterator(base, count); }
228+
ReferenceT operator[](unsigned index) const {
229+
assert(index < size() && "invalid index for value range");
230+
return *std::next(begin(), index);
231+
}
232+
233+
/// Return the size of this range.
234+
size_t size() const { return count; }
235+
236+
/// Return if the range is empty.
237+
bool empty() const { return size() == 0; }
238+
239+
/// Drop the first N elements, and keep M elements.
240+
DerivedT slice(unsigned n, unsigned m) const {
241+
assert(n + m <= size() && "invalid size specifiers");
242+
return DerivedT(DerivedT::offset_base(base, n), m);
243+
}
244+
245+
/// Drop the first n elements.
246+
DerivedT drop_front(unsigned n = 1) const {
247+
assert(size() >= n && "Dropping more elements than exist");
248+
return slice(n, size() - n);
249+
}
250+
/// Drop the last n elements.
251+
DerivedT drop_back(unsigned n = 1) const {
252+
assert(size() >= n && "Dropping more elements than exist");
253+
return DerivedT(base, size() - n);
254+
}
255+
256+
protected:
257+
indexed_accessor_range_base(BaseT base, ptrdiff_t count)
258+
: base(base), count(count) {}
259+
indexed_accessor_range_base(const indexed_accessor_range_base &) = default;
260+
indexed_accessor_range_base(indexed_accessor_range_base &&) = default;
261+
indexed_accessor_range_base &
262+
operator=(const indexed_accessor_range_base &) = default;
263+
264+
/// The base that owns the provided range of values.
265+
BaseT base;
266+
/// The size from the owning range.
267+
ptrdiff_t count;
268+
};
269+
} // end namespace detail
270+
271+
/// This class provides an implementation of a range of
272+
/// indexed_accessor_iterators where the base is not indexable. Ranges with
273+
/// bases that are offsetable should derive from indexed_accessor_range_base
274+
/// instead. Derived range classes are expected to implement the following
275+
/// static method:
276+
/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index)
277+
/// - Derefence an iterator pointing to a parent base at the given index.
278+
template <typename DerivedT, typename BaseT, typename T,
279+
typename PointerT = T *, typename ReferenceT = T &>
280+
class indexed_accessor_range
281+
: public detail::indexed_accessor_range_base<
282+
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
283+
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT> {
284+
protected:
285+
indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count)
286+
: detail::indexed_accessor_range_base<
287+
DerivedT, std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>(
288+
std::make_pair(base, startIndex), count) {}
289+
290+
private:
291+
/// See `detail::indexed_accessor_range_base` for details.
292+
static std::pair<BaseT, ptrdiff_t>
293+
offset_base(const std::pair<BaseT, ptrdiff_t> &base, ptrdiff_t index) {
294+
// We encode the internal base as a pair of the derived base and a start
295+
// index into the derived base.
296+
return std::make_pair(base.first, base.second + index);
297+
}
298+
/// See `detail::indexed_accessor_range_base` for details.
299+
static ReferenceT
300+
dereference_iterator(const std::pair<BaseT, ptrdiff_t> &base,
301+
ptrdiff_t index) {
302+
return DerivedT::dereference_iterator(base.first, base.second + index);
303+
}
304+
305+
/// Allow access to `offset_base` and `dereference_iterator`.
306+
friend detail::indexed_accessor_range_base<
307+
indexed_accessor_range<DerivedT, BaseT, T, PointerT, ReferenceT>,
308+
std::pair<BaseT, ptrdiff_t>, T, PointerT, ReferenceT>;
309+
};
310+
193311
/// Given a container of pairs, return a range over the second elements.
194312
template <typename ContainerTy> auto make_second_range(ContainerTy &&c) {
195313
return llvm::map_range(

mlir/lib/IR/Attributes.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
527527

528528
/// Accesses the Attribute value at this iterator position.
529529
Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
530-
auto owner = getFromOpaquePointer(object).cast<DenseElementsAttr>();
530+
auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
531531
Type eltTy = owner.getType().getElementType();
532532
if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) {
533533
if (intEltTy.getWidth() == 1)

0 commit comments

Comments
 (0)