|
| 1 | +//===- ScalableVectorType.h - Scalable Vector Helpers -----------*- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | + |
| 9 | +#ifndef MLIR_SUPPORT_SCALABLEVECTORTYPE_H |
| 10 | +#define MLIR_SUPPORT_SCALABLEVECTORTYPE_H |
| 11 | + |
| 12 | +#include "mlir/IR/BuiltinTypes.h" |
| 13 | +#include "mlir/Support/LLVM.h" |
| 14 | + |
| 15 | +namespace mlir { |
| 16 | + |
| 17 | +//===----------------------------------------------------------------------===// |
| 18 | +// VectorDim |
| 19 | +//===----------------------------------------------------------------------===// |
| 20 | + |
| 21 | +/// This class represents a dimension of a vector type. Unlike other ShapedTypes |
| 22 | +/// vector dimensions can have scalable quantities, which means the dimension |
| 23 | +/// has a known minimum size, which is scaled by a constant that is only |
| 24 | +/// known at runtime. |
| 25 | +class VectorDim { |
| 26 | +public: |
| 27 | + explicit constexpr VectorDim(int64_t quantity, bool scalable) |
| 28 | + : quantity(quantity), scalable(scalable) {}; |
| 29 | + |
| 30 | + /// Constructs a new fixed dimension. |
| 31 | + constexpr static VectorDim getFixed(int64_t quantity) { |
| 32 | + return VectorDim(quantity, false); |
| 33 | + } |
| 34 | + |
| 35 | + /// Constructs a new scalable dimension. |
| 36 | + constexpr static VectorDim getScalable(int64_t quantity) { |
| 37 | + return VectorDim(quantity, true); |
| 38 | + } |
| 39 | + |
| 40 | + /// Returns true if this dimension is scalable; |
| 41 | + constexpr bool isScalable() const { return scalable; } |
| 42 | + |
| 43 | + /// Returns true if this dimension is fixed. |
| 44 | + constexpr bool isFixed() const { return !isScalable(); } |
| 45 | + |
| 46 | + /// Returns the minimum number of elements this dimension can contain. |
| 47 | + constexpr int64_t getMinSize() const { return quantity; } |
| 48 | + |
| 49 | + /// If this dimension is fixed returns the number of elements, otherwise |
| 50 | + /// aborts. |
| 51 | + constexpr int64_t getFixedSize() const { |
| 52 | + assert(isFixed()); |
| 53 | + return quantity; |
| 54 | + } |
| 55 | + |
| 56 | + constexpr bool operator==(VectorDim const &dim) const { |
| 57 | + return quantity == dim.quantity && scalable == dim.scalable; |
| 58 | + } |
| 59 | + |
| 60 | + constexpr bool operator!=(VectorDim const &dim) const { |
| 61 | + return !(*this == dim); |
| 62 | + } |
| 63 | + |
| 64 | + /// Print the dim. |
| 65 | + void print(raw_ostream &os) { |
| 66 | + if (isScalable()) |
| 67 | + os << '['; |
| 68 | + os << getMinSize(); |
| 69 | + if (isScalable()) |
| 70 | + os << ']'; |
| 71 | + } |
| 72 | + |
| 73 | + /// Helper class for indexing into a list of sizes (and possibly empty) list |
| 74 | + /// of scalable dimensions, extracting VectorDim elements. |
| 75 | + struct Indexer { |
| 76 | + explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims) |
| 77 | + : sizes(sizes), scalableDims(scalableDims) { |
| 78 | + assert( |
| 79 | + scalableDims.empty() || |
| 80 | + sizes.size() == scalableDims.size() && |
| 81 | + "expected `scalableDims` to be empty or match `sizes` in length"); |
| 82 | + } |
| 83 | + |
| 84 | + VectorDim operator[](size_t idx) const { |
| 85 | + int64_t size = sizes[idx]; |
| 86 | + bool scalable = scalableDims.empty() ? false : scalableDims[idx]; |
| 87 | + return VectorDim(size, scalable); |
| 88 | + } |
| 89 | + |
| 90 | + ArrayRef<int64_t> sizes; |
| 91 | + ArrayRef<bool> scalableDims; |
| 92 | + }; |
| 93 | + |
| 94 | +private: |
| 95 | + int64_t quantity; |
| 96 | + bool scalable; |
| 97 | +}; |
| 98 | + |
| 99 | +inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) { |
| 100 | + dim.print(os); |
| 101 | + return os; |
| 102 | +} |
| 103 | + |
| 104 | +//===----------------------------------------------------------------------===// |
| 105 | +// VectorDimList |
| 106 | +//===----------------------------------------------------------------------===// |
| 107 | + |
| 108 | +/// Represents a non-owning list of vector dimensions. The underlying dimension |
| 109 | +/// sizes and scalability flags are stored a two seperate lists to match the |
| 110 | +/// storage of a VectorType. |
| 111 | +class VectorDimList : public VectorDim::Indexer { |
| 112 | +public: |
| 113 | + using VectorDim::Indexer::Indexer; |
| 114 | + |
| 115 | + class Iterator : public llvm::iterator_facade_base< |
| 116 | + Iterator, std::random_access_iterator_tag, VectorDim, |
| 117 | + std::ptrdiff_t, VectorDim, VectorDim> { |
| 118 | + public: |
| 119 | + Iterator(VectorDim::Indexer indexer, size_t index) |
| 120 | + : indexer(indexer), index(index) {}; |
| 121 | + |
| 122 | + // Iterator boilerplate. |
| 123 | + ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; } |
| 124 | + bool operator==(const Iterator &rhs) const { return index == rhs.index; } |
| 125 | + bool operator<(const Iterator &rhs) const { return index < rhs.index; } |
| 126 | + Iterator &operator+=(ptrdiff_t offset) { |
| 127 | + index += offset; |
| 128 | + return *this; |
| 129 | + } |
| 130 | + Iterator &operator-=(ptrdiff_t offset) { |
| 131 | + index -= offset; |
| 132 | + return *this; |
| 133 | + } |
| 134 | + VectorDim operator*() const { return indexer[index]; } |
| 135 | + |
| 136 | + VectorDim::Indexer getIndexer() const { return indexer; } |
| 137 | + ptrdiff_t getIndex() const { return index; } |
| 138 | + |
| 139 | + private: |
| 140 | + VectorDim::Indexer indexer; |
| 141 | + ptrdiff_t index; |
| 142 | + }; |
| 143 | + |
| 144 | + // Generic definitions. |
| 145 | + using value_type = VectorDim; |
| 146 | + using iterator = Iterator; |
| 147 | + using const_iterator = Iterator; |
| 148 | + using reverse_iterator = std::reverse_iterator<iterator>; |
| 149 | + using const_reverse_iterator = std::reverse_iterator<const_iterator>; |
| 150 | + using size_type = size_t; |
| 151 | + using difference_type = ptrdiff_t; |
| 152 | + |
| 153 | + /// Construct from iterator pair. |
| 154 | + VectorDimList(Iterator begin, Iterator end) |
| 155 | + : VectorDimList(VectorDimList(begin.getIndexer()) |
| 156 | + .slice(begin.getIndex(), end - begin)) {} |
| 157 | + |
| 158 | + VectorDimList(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer) {}; |
| 159 | + |
| 160 | + /// Construct from a VectorType. |
| 161 | + static VectorDimList from(VectorType vectorType) { |
| 162 | + if (!vectorType) |
| 163 | + return VectorDimList({}, {}); |
| 164 | + return VectorDimList(vectorType.getShape(), vectorType.getScalableDims()); |
| 165 | + } |
| 166 | + |
| 167 | + Iterator begin() const { return Iterator(*this, 0); } |
| 168 | + Iterator end() const { return Iterator(*this, size()); } |
| 169 | + |
| 170 | + /// Check if the dims are empty. |
| 171 | + bool empty() const { return sizes.empty(); } |
| 172 | + |
| 173 | + /// Get the number of dims. |
| 174 | + size_t size() const { return sizes.size(); } |
| 175 | + |
| 176 | + /// Return the first dim. |
| 177 | + VectorDim front() const { return (*this)[0]; } |
| 178 | + |
| 179 | + /// Return the last dim. |
| 180 | + VectorDim back() const { return (*this)[size() - 1]; } |
| 181 | + |
| 182 | + /// Chop of the first \p n dims, and keep the remaining \p m dims. |
| 183 | + VectorDimList slice(size_t n, size_t m) const { |
| 184 | + ArrayRef<int64_t> newSizes = sizes.slice(n, m); |
| 185 | + ArrayRef<bool> newScalableDims = |
| 186 | + scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m); |
| 187 | + return VectorDimList(newSizes, newScalableDims); |
| 188 | + } |
| 189 | + |
| 190 | + /// Drop the first \p n dims. |
| 191 | + VectorDimList dropFront(size_t n = 1) const { return slice(n, size() - n); } |
| 192 | + |
| 193 | + /// Drop the last \p n dims. |
| 194 | + VectorDimList dropBack(size_t n = 1) const { return slice(0, size() - n); } |
| 195 | + |
| 196 | + /// Return a copy of *this with only the first \p n elements. |
| 197 | + VectorDimList takeFront(size_t n = 1) const { |
| 198 | + if (n >= size()) |
| 199 | + return *this; |
| 200 | + return dropBack(size() - n); |
| 201 | + } |
| 202 | + |
| 203 | + /// Return a copy of *this with only the last \p n elements. |
| 204 | + VectorDimList takeBack(size_t n = 1) const { |
| 205 | + if (n >= size()) |
| 206 | + return *this; |
| 207 | + return dropFront(size() - n); |
| 208 | + } |
| 209 | + |
| 210 | + /// Return copy of *this with the first n dims matching the predicate removed. |
| 211 | + template <class PredicateT> |
| 212 | + VectorDimList dropWhile(PredicateT predicate) const { |
| 213 | + return VectorDimList(llvm::find_if_not(*this, predicate), end()); |
| 214 | + } |
| 215 | + |
| 216 | + /// Returns true if one or more of the dims are scalable. |
| 217 | + bool hasScalableDims() const { |
| 218 | + return llvm::is_contained(getScalableDims(), true); |
| 219 | + } |
| 220 | + |
| 221 | + /// Check for dim equality. |
| 222 | + bool equals(VectorDimList rhs) const { |
| 223 | + if (size() != rhs.size()) |
| 224 | + return false; |
| 225 | + return std::equal(begin(), end(), rhs.begin()); |
| 226 | + } |
| 227 | + |
| 228 | + /// Check for dim equality. |
| 229 | + bool equals(ArrayRef<VectorDim> rhs) const { |
| 230 | + if (size() != rhs.size()) |
| 231 | + return false; |
| 232 | + return std::equal(begin(), end(), rhs.begin()); |
| 233 | + } |
| 234 | + |
| 235 | + /// Return the underlying sizes. |
| 236 | + ArrayRef<int64_t> getSizes() const { return sizes; } |
| 237 | + |
| 238 | + /// Return the underlying scalable dims. |
| 239 | + ArrayRef<bool> getScalableDims() const { return scalableDims; } |
| 240 | +}; |
| 241 | + |
| 242 | +inline bool operator==(VectorDimList lhs, VectorDimList rhs) { |
| 243 | + return lhs.equals(rhs); |
| 244 | +} |
| 245 | + |
| 246 | +inline bool operator!=(VectorDimList lhs, VectorDimList rhs) { |
| 247 | + return !(lhs == rhs); |
| 248 | +} |
| 249 | + |
| 250 | +inline bool operator==(VectorDimList lhs, ArrayRef<VectorDim> rhs) { |
| 251 | + return lhs.equals(rhs); |
| 252 | +} |
| 253 | + |
| 254 | +inline bool operator!=(VectorDimList lhs, ArrayRef<VectorDim> rhs) { |
| 255 | + return !(lhs == rhs); |
| 256 | +} |
| 257 | + |
| 258 | +//===----------------------------------------------------------------------===// |
| 259 | +// ScalableVectorType |
| 260 | +//===----------------------------------------------------------------------===// |
| 261 | + |
| 262 | +/// A pseudo-type that wraps a VectorType that aims to provide safe APIs for |
| 263 | +/// working with scalable vectors. Slightly contrary to the name this class can |
| 264 | +/// represent both fixed and scalable vectors, however, if you are only dealing |
| 265 | +/// with fixed vectors the plain VectorType is likely more convenient. |
| 266 | +/// |
| 267 | +/// The main difference from the regular VectorType is that vector dimensions |
| 268 | +/// are _not_ represented as `int64_t`, which does not allow encoding the |
| 269 | +/// scalability into the dimension. Instead, vector dimensions are represented |
| 270 | +/// by a VectorDim class. A VectorDim stores both the size and scalability of a |
| 271 | +/// dimension. This makes common errors like only checking the size (but not the |
| 272 | +/// scalability) impossible (without being explicit with your intention). |
| 273 | +/// |
| 274 | +/// To make this convenient to work with there is VectorDimList which provides |
| 275 | +/// ArrayRef-like helper methods along with an iterator for VectorDims. |
| 276 | +/// |
| 277 | +/// ScalableVectorType can freely converted to VectorType (and vice versa), |
| 278 | +/// though there are two main ways to acquire a ScalableVectorType. |
| 279 | +/// |
| 280 | +/// Assignment: |
| 281 | +/// |
| 282 | +/// This does not check the scalability of `myVectorType`. This is valid and the |
| 283 | +/// helpers on ScalableVectorType will function as normal. |
| 284 | +/// ```c++ |
| 285 | +/// VectorType myVectorType = ...; |
| 286 | +/// ScalableVectorType scalableVector = myVectorType; |
| 287 | +/// ``` |
| 288 | +/// |
| 289 | +/// Casting: |
| 290 | +/// |
| 291 | +/// This checks the scalability of `myVectorType`. In this case, |
| 292 | +/// `scalableVector` will be falsy if `myVectorType` contains no scalable dims. |
| 293 | +/// ```c++ |
| 294 | +/// VectorType myVectorType = ...; |
| 295 | +/// auto scalableVector = dyn_cast<ScalableVectorType>(myVectorType); |
| 296 | +/// ``` |
| 297 | +class ScalableVectorType { |
| 298 | +public: |
| 299 | + using Dim = VectorDim; |
| 300 | + using DimList = VectorDimList; |
| 301 | + |
| 302 | + ScalableVectorType(VectorType vectorType) : vectorType(vectorType) {}; |
| 303 | + |
| 304 | + /// Construct a new ScalableVectorType. |
| 305 | + static ScalableVectorType get(DimList shape, Type elementType) { |
| 306 | + return VectorType::get(shape.getSizes(), elementType, |
| 307 | + shape.getScalableDims()); |
| 308 | + } |
| 309 | + |
| 310 | + /// Construct a new ScalableVectorType. |
| 311 | + static ScalableVectorType get(ArrayRef<Dim> shape, Type elementType) { |
| 312 | + SmallVector<int64_t> sizes; |
| 313 | + SmallVector<bool> scalableDims; |
| 314 | + sizes.reserve(shape.size()); |
| 315 | + scalableDims.reserve(shape.size()); |
| 316 | + for (Dim dim : shape) { |
| 317 | + sizes.push_back(dim.getMinSize()); |
| 318 | + scalableDims.push_back(dim.isScalable()); |
| 319 | + } |
| 320 | + return VectorType::get(sizes, elementType, scalableDims); |
| 321 | + } |
| 322 | + |
| 323 | + inline static bool classof(Type type) { |
| 324 | + auto vectorType = dyn_cast_if_present<VectorType>(type); |
| 325 | + return vectorType && vectorType.isScalable(); |
| 326 | + } |
| 327 | + |
| 328 | + /// Returns the value of the specified dimension (including scalability). |
| 329 | + Dim getDim(unsigned idx) const { |
| 330 | + assert(idx < getRank() && "invalid dim index for vector type"); |
| 331 | + return getDims()[idx]; |
| 332 | + } |
| 333 | + |
| 334 | + /// Returns the dimensions of this vector type (including scalability). |
| 335 | + DimList getDims() const { |
| 336 | + return DimList(vectorType.getShape(), vectorType.getScalableDims()); |
| 337 | + } |
| 338 | + |
| 339 | + /// Returns the rank of this vector type. |
| 340 | + int64_t getRank() const { return vectorType.getRank(); } |
| 341 | + |
| 342 | + /// Returns true if the vector contains scalable dimensions. |
| 343 | + bool isScalable() const { return vectorType.isScalable(); } |
| 344 | + bool allDimsScalable() const { return vectorType.allDimsScalable(); } |
| 345 | + |
| 346 | + /// Returns the element type of this vector type. |
| 347 | + Type getElementType() const { return vectorType.getElementType(); } |
| 348 | + |
| 349 | + /// Clones this vector type with a new element type. |
| 350 | + ScalableVectorType clone(Type elementType) { |
| 351 | + return vectorType.clone(elementType); |
| 352 | + } |
| 353 | + |
| 354 | + operator VectorType() const { return vectorType; } |
| 355 | + |
| 356 | + explicit operator bool() const { return bool(vectorType); } |
| 357 | + |
| 358 | +private: |
| 359 | + VectorType vectorType; |
| 360 | +}; |
| 361 | + |
| 362 | +} // namespace mlir |
| 363 | + |
| 364 | +#endif |
0 commit comments