|
12 | 12 | #include "mlir/IR/BuiltinAttributeInterfaces.h"
|
13 | 13 | #include "mlir/IR/BuiltinTypeInterfaces.h"
|
14 | 14 | #include "mlir/Support/ADTExtras.h"
|
| 15 | +#include "llvm/ADT/STLExtras.h" |
15 | 16 |
|
16 | 17 | namespace llvm {
|
17 | 18 | class BitVector;
|
@@ -181,6 +182,239 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
|
181 | 182 | operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
|
182 | 183 | };
|
183 | 184 |
|
| 185 | +//===----------------------------------------------------------------------===// |
| 186 | +// VectorDim |
| 187 | +//===----------------------------------------------------------------------===// |
| 188 | + |
| 189 | +/// This class represents a dimension of a vector type. Unlike other ShapedTypes |
| 190 | +/// vector dimensions can have scalable quantities, which means the dimension |
| 191 | +/// has a known minimum size, which is scaled by a constant that is only |
| 192 | +/// known at runtime. |
| 193 | +class VectorDim { |
| 194 | +public: |
| 195 | + explicit constexpr VectorDim(int64_t quantity, bool scalable) |
| 196 | + : quantity(quantity), scalable(scalable){}; |
| 197 | + |
| 198 | + /// Constructs a new fixed dimension. |
| 199 | + constexpr static VectorDim getFixed(int64_t quantity) { |
| 200 | + return VectorDim(quantity, false); |
| 201 | + } |
| 202 | + |
| 203 | + /// Constructs a new scalable dimension. |
| 204 | + constexpr static VectorDim getScalable(int64_t quantity) { |
| 205 | + return VectorDim(quantity, true); |
| 206 | + } |
| 207 | + |
| 208 | + /// Returns true if this dimension is scalable; |
| 209 | + constexpr bool isScalable() const { return scalable; } |
| 210 | + |
| 211 | + /// Returns true if this dimension is fixed. |
| 212 | + constexpr bool isFixed() const { return !isScalable(); } |
| 213 | + |
| 214 | + /// Returns the minimum number of elements this dimension can contain. |
| 215 | + constexpr int64_t getMinSize() const { return quantity; } |
| 216 | + |
| 217 | + /// If this dimension is fixed returns the number of elements, otherwise |
| 218 | + /// aborts. |
| 219 | + constexpr int64_t getFixedSize() const { |
| 220 | + assert(isFixed()); |
| 221 | + return quantity; |
| 222 | + } |
| 223 | + |
| 224 | + constexpr bool operator==(VectorDim const &dim) const { |
| 225 | + return quantity == dim.quantity && scalable == dim.scalable; |
| 226 | + } |
| 227 | + |
| 228 | + constexpr bool operator!=(VectorDim const &dim) const { |
| 229 | + return !(*this == dim); |
| 230 | + } |
| 231 | + |
| 232 | + /// Print the dim. |
| 233 | + void print(raw_ostream &os) { |
| 234 | + if (isScalable()) |
| 235 | + os << '['; |
| 236 | + os << getMinSize(); |
| 237 | + if (isScalable()) |
| 238 | + os << ']'; |
| 239 | + } |
| 240 | + |
| 241 | + /// Helper class for indexing into a list of sizes (and possibly empty) list |
| 242 | + /// of scalable dimensions, extracting VectorDim elements. |
| 243 | + struct Indexer { |
| 244 | + explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims) |
| 245 | + : sizes(sizes), scalableDims(scalableDims) { |
| 246 | + assert( |
| 247 | + scalableDims.empty() || |
| 248 | + sizes.size() == scalableDims.size() && |
| 249 | + "expected `scalableDims` to be empty or match `sizes` in length"); |
| 250 | + } |
| 251 | + |
| 252 | + VectorDim operator[](size_t idx) const { |
| 253 | + int64_t size = sizes[idx]; |
| 254 | + bool scalable = scalableDims.empty() ? false : scalableDims[idx]; |
| 255 | + return VectorDim(size, scalable); |
| 256 | + } |
| 257 | + |
| 258 | + ArrayRef<int64_t> sizes; |
| 259 | + ArrayRef<bool> scalableDims; |
| 260 | + }; |
| 261 | + |
| 262 | +private: |
| 263 | + int64_t quantity; |
| 264 | + bool scalable; |
| 265 | +}; |
| 266 | + |
| 267 | +inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) { |
| 268 | + dim.print(os); |
| 269 | + return os; |
| 270 | +} |
| 271 | + |
| 272 | +//===----------------------------------------------------------------------===// |
| 273 | +// VectorDims |
| 274 | +//===----------------------------------------------------------------------===// |
| 275 | + |
| 276 | +/// Represents a non-owning list of vector dimensions. The underlying dimension |
| 277 | +/// sizes and scalability flags are stored a two seperate lists to match the |
| 278 | +/// storage of a VectorType. |
| 279 | +class VectorDims : public VectorDim::Indexer { |
| 280 | +public: |
| 281 | + using VectorDim::Indexer::Indexer; |
| 282 | + |
| 283 | + class Iterator : public llvm::iterator_facade_base< |
| 284 | + Iterator, std::random_access_iterator_tag, VectorDim, |
| 285 | + std::ptrdiff_t, VectorDim, VectorDim> { |
| 286 | + public: |
| 287 | + Iterator(VectorDim::Indexer indexer, size_t index) |
| 288 | + : indexer(indexer), index(index){}; |
| 289 | + |
| 290 | + // Iterator boilerplate. |
| 291 | + ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; } |
| 292 | + bool operator==(const Iterator &rhs) const { return index == rhs.index; } |
| 293 | + bool operator<(const Iterator &rhs) const { return index < rhs.index; } |
| 294 | + Iterator &operator+=(ptrdiff_t offset) { |
| 295 | + index += offset; |
| 296 | + return *this; |
| 297 | + } |
| 298 | + Iterator &operator-=(ptrdiff_t offset) { |
| 299 | + index -= offset; |
| 300 | + return *this; |
| 301 | + } |
| 302 | + VectorDim operator*() const { return indexer[index]; } |
| 303 | + |
| 304 | + VectorDim::Indexer getIndexer() const { return indexer; } |
| 305 | + ptrdiff_t getIndex() const { return index; } |
| 306 | + |
| 307 | + private: |
| 308 | + VectorDim::Indexer indexer; |
| 309 | + ptrdiff_t index; |
| 310 | + }; |
| 311 | + |
| 312 | + // Generic definitions. |
| 313 | + using value_type = VectorDim; |
| 314 | + using iterator = Iterator; |
| 315 | + using const_iterator = Iterator; |
| 316 | + using reverse_iterator = std::reverse_iterator<iterator>; |
| 317 | + using const_reverse_iterator = std::reverse_iterator<const_iterator>; |
| 318 | + using size_type = size_t; |
| 319 | + using difference_type = ptrdiff_t; |
| 320 | + |
| 321 | + /// Construct from iterator pair. |
| 322 | + VectorDims(Iterator begin, Iterator end) |
| 323 | + : VectorDims(VectorDims(begin.getIndexer()) |
| 324 | + .slice(begin.getIndex(), end - begin)) {} |
| 325 | + |
| 326 | + VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){}; |
| 327 | + |
| 328 | + Iterator begin() const { return Iterator(*this, 0); } |
| 329 | + Iterator end() const { return Iterator(*this, size()); } |
| 330 | + |
| 331 | + /// Check if the dims are empty. |
| 332 | + bool empty() const { return sizes.empty(); } |
| 333 | + |
| 334 | + /// Get the number of dims. |
| 335 | + size_t size() const { return sizes.size(); } |
| 336 | + |
| 337 | + /// Return the first dim. |
| 338 | + VectorDim front() const { return (*this)[0]; } |
| 339 | + |
| 340 | + /// Return the last dim. |
| 341 | + VectorDim back() const { return (*this)[size() - 1]; } |
| 342 | + |
| 343 | + /// Chop of thie first \p n dims, and keep the remaining \p m |
| 344 | + /// dims. |
| 345 | + VectorDims slice(size_t n, size_t m) const { |
| 346 | + ArrayRef<int64_t> newSizes = sizes.slice(n, m); |
| 347 | + ArrayRef<bool> newScalableDims = |
| 348 | + scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m); |
| 349 | + return VectorDims(newSizes, newScalableDims); |
| 350 | + } |
| 351 | + |
| 352 | + /// Drop the first \p n dims. |
| 353 | + VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); } |
| 354 | + |
| 355 | + /// Drop the last \p n dims. |
| 356 | + VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); } |
| 357 | + |
| 358 | + /// Return a copy of *this with only the first \p n elements. |
| 359 | + VectorDims takeFront(size_t n = 1) const { |
| 360 | + if (n >= size()) |
| 361 | + return *this; |
| 362 | + return dropBack(size() - n); |
| 363 | + } |
| 364 | + |
| 365 | + /// Return a copy of *this with only the last \p n elements. |
| 366 | + VectorDims takeBack(size_t n = 1) const { |
| 367 | + if (n >= size()) |
| 368 | + return *this; |
| 369 | + return dropFront(size() - n); |
| 370 | + } |
| 371 | + |
| 372 | + /// Return copy of *this with the first n dims matching the predicate removed. |
| 373 | + template <class PredicateT> |
| 374 | + VectorDims dropWhile(PredicateT predicate) const { |
| 375 | + return VectorDims(llvm::find_if_not(*this, predicate), end()); |
| 376 | + } |
| 377 | + |
| 378 | + /// Returns true if one or more of the dims are scalable. |
| 379 | + bool hasScalableDims() const { |
| 380 | + return llvm::is_contained(getScalableDims(), true); |
| 381 | + } |
| 382 | + |
| 383 | + /// Check for dim equality. |
| 384 | + bool equals(VectorDims rhs) const { |
| 385 | + if (size() != rhs.size()) |
| 386 | + return false; |
| 387 | + return std::equal(begin(), end(), rhs.begin()); |
| 388 | + } |
| 389 | + |
| 390 | + /// Check for dim equality. |
| 391 | + bool equals(ArrayRef<VectorDim> rhs) const { |
| 392 | + if (size() != rhs.size()) |
| 393 | + return false; |
| 394 | + return std::equal(begin(), end(), rhs.begin()); |
| 395 | + } |
| 396 | + |
| 397 | + /// Return the underlying sizes. |
| 398 | + ArrayRef<int64_t> getSizes() const { return sizes; } |
| 399 | + |
| 400 | + /// Return the underlying scalable dims. |
| 401 | + ArrayRef<bool> getScalableDims() const { return scalableDims; } |
| 402 | +}; |
| 403 | + |
| 404 | +inline bool operator==(VectorDims lhs, VectorDims rhs) { |
| 405 | + return lhs.equals(rhs); |
| 406 | +} |
| 407 | + |
| 408 | +inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); } |
| 409 | + |
| 410 | +inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) { |
| 411 | + return lhs.equals(rhs); |
| 412 | +} |
| 413 | + |
| 414 | +inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) { |
| 415 | + return !(lhs == rhs); |
| 416 | +} |
| 417 | + |
184 | 418 | } // namespace mlir
|
185 | 419 |
|
186 | 420 | //===----------------------------------------------------------------------===//
|
|
0 commit comments