Skip to content

Commit ca8250c

Browse files
committed
[mlir] Add first-class support for scalability in VectorType dims
Currently, the shape of a VectorType is stored in two separate lists. The 'shape' which comes from ShapedType, which does not have a way to represent scalability, and the 'scalableDims', an additional list of bools attached to VectorType. This can be somewhat cumbersome to work with, and easy to ignore the scalability of a dim, producing incorrect results. For example, to correctly trim leading unit dims of a VectorType, currently, you need to do something like: ```c++ while (!newShape.empty() && newShape.front() == 1 && !newScalableDims.front()) { newShape = newShape.drop_front(1); newScalableDims = newScalableDims.drop_front(1); } ``` Which would be wrong if you (more naturally) wrote it as: ```c++ auto newShape = vectorType.getShape().drop_while([](int64_t dim) { return dim == 1; }); ``` As this would trim scalable one dims (`[1]`), which are not unit dims like their fixed counterpart. This patch does not change the storage of the VectorType, but instead adds new scalability-safe accessors and iterators. Two new methods are added to VectorType: ``` /// Returns the value of the specified dimension (including scalability) VectorDim VectorType::getDim(unsigned idx) /// Returns the dimensions of this vector type (including scalability) VectorDims VectorType::getDims() ``` These are backed by two new classes: `VectorDim` and `VectorDims`. `VectorDim` represents a single dimension of a VectorType. It can be a fixed or scalable quantity. It cannot be implicitly converted to/from an integer, so you must specify the kind of quantity you expect in comparisons. `VectorDims` represents a non-owning list of vector dimensions, backed by separate size and scalability lists (matching the storage of VectorType). This class has an iterator, and a few common helper methods (similar to that of ArrayRef). There are also new builders to construct VectorTypes from both the `VectorDims` class and an `ArrayRef<VectorDim>`. With these changes the previous example becomes: ```c++ auto newDims = vectorType.getDims().dropWhile([](VectorDim dim) { return dim == VectorDim::getFixed(1); }); ``` Which (to me) is easier to read, and safer as it is not possible to forget check the scalability of the dim. Just comparing with `1`, would fail to build.
1 parent 6ea3344 commit ca8250c

File tree

3 files changed

+358
-0
lines changed

3 files changed

+358
-0
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/IR/BuiltinAttributeInterfaces.h"
1313
#include "mlir/IR/BuiltinTypeInterfaces.h"
1414
#include "mlir/Support/ADTExtras.h"
15+
#include "llvm/ADT/STLExtras.h"
1516

1617
namespace llvm {
1718
class BitVector;
@@ -181,6 +182,239 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
181182
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
182183
};
183184

185+
//===----------------------------------------------------------------------===//
186+
// VectorDim
187+
//===----------------------------------------------------------------------===//
188+
189+
/// This class represents a dimension of a vector type. Unlike other ShapedTypes
190+
/// vector dimensions can have scalable quantities, which means the dimension
191+
/// has a known minimum size, which is scaled by a constant that is only
192+
/// known at runtime.
193+
class VectorDim {
194+
public:
195+
explicit constexpr VectorDim(int64_t quantity, bool scalable)
196+
: quantity(quantity), scalable(scalable){};
197+
198+
/// Constructs a new fixed dimension.
199+
constexpr static VectorDim getFixed(int64_t quantity) {
200+
return VectorDim(quantity, false);
201+
}
202+
203+
/// Constructs a new scalable dimension.
204+
constexpr static VectorDim getScalable(int64_t quantity) {
205+
return VectorDim(quantity, true);
206+
}
207+
208+
/// Returns true if this dimension is scalable;
209+
constexpr bool isScalable() const { return scalable; }
210+
211+
/// Returns true if this dimension is fixed.
212+
constexpr bool isFixed() const { return !isScalable(); }
213+
214+
/// Returns the minimum number of elements this dimension can contain.
215+
constexpr int64_t getMinSize() const { return quantity; }
216+
217+
/// If this dimension is fixed returns the number of elements, otherwise
218+
/// aborts.
219+
constexpr int64_t getFixedSize() const {
220+
assert(isFixed());
221+
return quantity;
222+
}
223+
224+
constexpr bool operator==(VectorDim const &dim) const {
225+
return quantity == dim.quantity && scalable == dim.scalable;
226+
}
227+
228+
constexpr bool operator!=(VectorDim const &dim) const {
229+
return !(*this == dim);
230+
}
231+
232+
/// Print the dim.
233+
void print(raw_ostream &os) {
234+
if (isScalable())
235+
os << '[';
236+
os << getMinSize();
237+
if (isScalable())
238+
os << ']';
239+
}
240+
241+
/// Helper class for indexing into a list of sizes (and possibly empty) list
242+
/// of scalable dimensions, extracting VectorDim elements.
243+
struct Indexer {
244+
explicit Indexer(ArrayRef<int64_t> sizes, ArrayRef<bool> scalableDims)
245+
: sizes(sizes), scalableDims(scalableDims) {
246+
assert(
247+
scalableDims.empty() ||
248+
sizes.size() == scalableDims.size() &&
249+
"expected `scalableDims` to be empty or match `sizes` in length");
250+
}
251+
252+
VectorDim operator[](size_t idx) const {
253+
int64_t size = sizes[idx];
254+
bool scalable = scalableDims.empty() ? false : scalableDims[idx];
255+
return VectorDim(size, scalable);
256+
}
257+
258+
ArrayRef<int64_t> sizes;
259+
ArrayRef<bool> scalableDims;
260+
};
261+
262+
private:
263+
int64_t quantity;
264+
bool scalable;
265+
};
266+
267+
inline raw_ostream &operator<<(raw_ostream &os, VectorDim dim) {
268+
dim.print(os);
269+
return os;
270+
}
271+
272+
//===----------------------------------------------------------------------===//
273+
// VectorDims
274+
//===----------------------------------------------------------------------===//
275+
276+
/// Represents a non-owning list of vector dimensions. The underlying dimension
277+
/// sizes and scalability flags are stored a two seperate lists to match the
278+
/// storage of a VectorType.
279+
class VectorDims : public VectorDim::Indexer {
280+
public:
281+
using VectorDim::Indexer::Indexer;
282+
283+
class Iterator : public llvm::iterator_facade_base<
284+
Iterator, std::random_access_iterator_tag, VectorDim,
285+
std::ptrdiff_t, VectorDim, VectorDim> {
286+
public:
287+
Iterator(VectorDim::Indexer indexer, size_t index)
288+
: indexer(indexer), index(index){};
289+
290+
// Iterator boilerplate.
291+
ptrdiff_t operator-(const Iterator &rhs) const { return index - rhs.index; }
292+
bool operator==(const Iterator &rhs) const { return index == rhs.index; }
293+
bool operator<(const Iterator &rhs) const { return index < rhs.index; }
294+
Iterator &operator+=(ptrdiff_t offset) {
295+
index += offset;
296+
return *this;
297+
}
298+
Iterator &operator-=(ptrdiff_t offset) {
299+
index -= offset;
300+
return *this;
301+
}
302+
VectorDim operator*() const { return indexer[index]; }
303+
304+
VectorDim::Indexer getIndexer() const { return indexer; }
305+
ptrdiff_t getIndex() const { return index; }
306+
307+
private:
308+
VectorDim::Indexer indexer;
309+
ptrdiff_t index;
310+
};
311+
312+
// Generic definitions.
313+
using value_type = VectorDim;
314+
using iterator = Iterator;
315+
using const_iterator = Iterator;
316+
using reverse_iterator = std::reverse_iterator<iterator>;
317+
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
318+
using size_type = size_t;
319+
using difference_type = ptrdiff_t;
320+
321+
/// Construct from iterator pair.
322+
VectorDims(Iterator begin, Iterator end)
323+
: VectorDims(VectorDims(begin.getIndexer())
324+
.slice(begin.getIndex(), end - begin)) {}
325+
326+
VectorDims(VectorDim::Indexer indexer) : VectorDim::Indexer(indexer){};
327+
328+
Iterator begin() const { return Iterator(*this, 0); }
329+
Iterator end() const { return Iterator(*this, size()); }
330+
331+
/// Check if the dims are empty.
332+
bool empty() const { return sizes.empty(); }
333+
334+
/// Get the number of dims.
335+
size_t size() const { return sizes.size(); }
336+
337+
/// Return the first dim.
338+
VectorDim front() const { return (*this)[0]; }
339+
340+
/// Return the last dim.
341+
VectorDim back() const { return (*this)[size() - 1]; }
342+
343+
/// Chop of thie first \p n dims, and keep the remaining \p m
344+
/// dims.
345+
VectorDims slice(size_t n, size_t m) const {
346+
ArrayRef<int64_t> newSizes = sizes.slice(n, m);
347+
ArrayRef<bool> newScalableDims =
348+
scalableDims.empty() ? ArrayRef<bool>{} : scalableDims.slice(n, m);
349+
return VectorDims(newSizes, newScalableDims);
350+
}
351+
352+
/// Drop the first \p n dims.
353+
VectorDims dropFront(size_t n = 1) const { return slice(n, size() - n); }
354+
355+
/// Drop the last \p n dims.
356+
VectorDims dropBack(size_t n = 1) const { return slice(0, size() - n); }
357+
358+
/// Return a copy of *this with only the first \p n elements.
359+
VectorDims takeFront(size_t n = 1) const {
360+
if (n >= size())
361+
return *this;
362+
return dropBack(size() - n);
363+
}
364+
365+
/// Return a copy of *this with only the last \p n elements.
366+
VectorDims takeBack(size_t n = 1) const {
367+
if (n >= size())
368+
return *this;
369+
return dropFront(size() - n);
370+
}
371+
372+
/// Return copy of *this with the first n dims matching the predicate removed.
373+
template <class PredicateT>
374+
VectorDims dropWhile(PredicateT predicate) const {
375+
return VectorDims(llvm::find_if_not(*this, predicate), end());
376+
}
377+
378+
/// Returns true if one or more of the dims are scalable.
379+
bool hasScalableDims() const {
380+
return llvm::is_contained(getScalableDims(), true);
381+
}
382+
383+
/// Check for dim equality.
384+
bool equals(VectorDims rhs) const {
385+
if (size() != rhs.size())
386+
return false;
387+
return std::equal(begin(), end(), rhs.begin());
388+
}
389+
390+
/// Check for dim equality.
391+
bool equals(ArrayRef<VectorDim> rhs) const {
392+
if (size() != rhs.size())
393+
return false;
394+
return std::equal(begin(), end(), rhs.begin());
395+
}
396+
397+
/// Return the underlying sizes.
398+
ArrayRef<int64_t> getSizes() const { return sizes; }
399+
400+
/// Return the underlying scalable dims.
401+
ArrayRef<bool> getScalableDims() const { return scalableDims; }
402+
};
403+
404+
inline bool operator==(VectorDims lhs, VectorDims rhs) {
405+
return lhs.equals(rhs);
406+
}
407+
408+
inline bool operator!=(VectorDims lhs, VectorDims rhs) { return !(lhs == rhs); }
409+
410+
inline bool operator==(VectorDims lhs, ArrayRef<VectorDim> rhs) {
411+
return lhs.equals(rhs);
412+
}
413+
414+
inline bool operator!=(VectorDims lhs, ArrayRef<VectorDim> rhs) {
415+
return !(lhs == rhs);
416+
}
417+
184418
} // namespace mlir
185419

186420
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,13 +1114,36 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
11141114
scalableDims = isScalableVec;
11151115
}
11161116
return $_get(elementType.getContext(), shape, elementType, scalableDims);
1117+
}]>,
1118+
TypeBuilderWithInferredContext<(ins "Type":$elementType, "ArrayRef<VectorDim>": $shape), [{
1119+
SmallVector<int64_t> sizes;
1120+
SmallVector<bool> scalableDims;
1121+
for (VectorDim dim : shape) {
1122+
sizes.push_back(dim.getMinSize());
1123+
scalableDims.push_back(dim.isScalable());
1124+
}
1125+
return get(sizes, elementType, scalableDims);
1126+
}]>,
1127+
TypeBuilderWithInferredContext<(ins "Type":$elementType, "VectorDims": $shape), [{
1128+
return get(shape.getSizes(), elementType, shape.getScalableDims());
11171129
}]>
11181130
];
11191131
let extraClassDeclaration = [{
11201132
/// This is a builder type that keeps local references to arguments.
11211133
/// Arguments that are passed into the builder must outlive the builder.
11221134
class Builder;
11231135

1136+
/// Returns the value of the specified dimension (including scalability).
1137+
VectorDim getDim(unsigned idx) const {
1138+
assert(idx < getRank() && "invalid dim index for vector type");
1139+
return getDims()[idx];
1140+
}
1141+
1142+
/// Returns the dimensions of this vector type (including scalability).
1143+
VectorDims getDims() const {
1144+
return VectorDims(getShape(), getScalableDims());
1145+
}
1146+
11241147
/// Returns true if the given type can be used as an element of a vector
11251148
/// type. In particular, vectors can consist of integer, index, or float
11261149
/// primitives.

mlir/unittests/IR/ShapedTypeTest.cpp

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,4 +226,105 @@ TEST(ShapedTypeTest, RankedTensorTypeBuilder) {
226226
}
227227
}
228228

229+
TEST(ShapedTypeTest, VectorDims) {
230+
MLIRContext context;
231+
Type f32 = FloatType::getF32(&context);
232+
233+
SmallVector<VectorDim> dims{VectorDim::getFixed(2), VectorDim::getScalable(4),
234+
VectorDim::getFixed(8), VectorDim::getScalable(9),
235+
VectorDim::getFixed(1)};
236+
VectorType vectorType = VectorType::get(f32, dims);
237+
238+
// Directly check values
239+
{
240+
auto dim0 = vectorType.getDim(0);
241+
ASSERT_EQ(dim0.getMinSize(), 2);
242+
ASSERT_TRUE(dim0.isFixed());
243+
244+
auto dim1 = vectorType.getDim(1);
245+
ASSERT_EQ(dim1.getMinSize(), 4);
246+
ASSERT_TRUE(dim1.isScalable());
247+
248+
auto dim2 = vectorType.getDim(2);
249+
ASSERT_EQ(dim2.getMinSize(), 8);
250+
ASSERT_TRUE(dim2.isFixed());
251+
252+
auto dim3 = vectorType.getDim(3);
253+
ASSERT_EQ(dim3.getMinSize(), 9);
254+
ASSERT_TRUE(dim3.isScalable());
255+
256+
auto dim4 = vectorType.getDim(4);
257+
ASSERT_EQ(dim4.getMinSize(), 1);
258+
ASSERT_TRUE(dim4.isFixed());
259+
}
260+
261+
// Test indexing via getDim(idx)
262+
{
263+
for (unsigned i = 0; i < dims.size(); i++)
264+
ASSERT_EQ(vectorType.getDim(i), dims[i]);
265+
}
266+
267+
// Test using VectorDims::Iterator in for-each loop
268+
{
269+
unsigned i = 0;
270+
for (VectorDim dim : vectorType.getDims())
271+
ASSERT_EQ(dim, dims[i++]);
272+
ASSERT_EQ(i, vectorType.getRank());
273+
}
274+
275+
// Test using VectorDims::Iterator in LLVM iterator helper
276+
{
277+
for (auto [dim, expectedDim] :
278+
llvm::zip_equal(vectorType.getDims(), dims)) {
279+
ASSERT_EQ(dim, expectedDim);
280+
}
281+
}
282+
283+
// Test dropFront()
284+
{
285+
auto vectorDims = vectorType.getDims();
286+
auto newDims = vectorDims.dropFront();
287+
288+
ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
289+
for (unsigned i = 0; i < newDims.size(); i++)
290+
ASSERT_EQ(newDims[i], vectorDims[i + 1]);
291+
}
292+
293+
// Test dropBack()
294+
{
295+
auto vectorDims = vectorType.getDims();
296+
auto newDims = vectorDims.dropBack();
297+
298+
ASSERT_EQ(newDims.size(), vectorDims.size() - 1);
299+
for (unsigned i = 0; i < newDims.size(); i++)
300+
ASSERT_EQ(newDims[i], vectorDims[i]);
301+
}
302+
303+
// Test front()
304+
{ ASSERT_EQ(vectorType.getDims().front(), VectorDim::getFixed(2)); }
305+
306+
// Test back()
307+
{ ASSERT_EQ(vectorType.getDims().back(), VectorDim::getFixed(1)); }
308+
309+
// Test dropWhile.
310+
{
311+
SmallVector<VectorDim> dims{
312+
VectorDim::getFixed(1), VectorDim::getFixed(1), VectorDim::getFixed(1),
313+
VectorDim::getScalable(1), VectorDim::getScalable(4)};
314+
315+
VectorType vectorTypeWithLeadingUnitDims = VectorType::get(f32, dims);
316+
ASSERT_EQ(vectorTypeWithLeadingUnitDims.getDims().size(),
317+
unsigned(vectorTypeWithLeadingUnitDims.getRank()));
318+
319+
// Drop leading unit dims.
320+
auto withoutLeadingUnitDims =
321+
vectorTypeWithLeadingUnitDims.getDims().dropWhile(
322+
[](VectorDim dim) { return dim == VectorDim::getFixed(1); });
323+
324+
SmallVector<VectorDim> expectedDims{VectorDim::getScalable(1),
325+
VectorDim::getScalable(4)};
326+
ASSERT_EQ(withoutLeadingUnitDims, expectedDims);
327+
}
328+
}
329+
229330
} // namespace

0 commit comments

Comments
 (0)