Skip to content

Commit df411fb

Browse files
authored
[MLIR][DataLayout] Add support for scalable vectors (#89349)
This commit extends the data layout to support scalable vectors. For scalable vectors, the `TypeSize`'s scalable field is set accordingly, and the alignment information remains the same as for normal vectors. This behavior is in sync with what LLVM's data layout queries are producing. Before this change, scalable vectors incorrectly returned the same size as "normal" vectors.
1 parent 4d7f3d9 commit df411fb

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

mlir/lib/Interfaces/DataLayoutInterfaces.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,12 @@ mlir::detail::getDefaultTypeSizeInBits(Type type, const DataLayout &dataLayout,
7575
// there is no bit-packing at the moment element sizes are taken in bytes and
7676
// multiplied with 8 bits.
7777
// TODO: make this extensible.
78-
if (auto vecType = dyn_cast<VectorType>(type))
79-
return vecType.getNumElements() / vecType.getShape().back() *
80-
llvm::PowerOf2Ceil(vecType.getShape().back()) *
81-
dataLayout.getTypeSize(vecType.getElementType()) * 8;
78+
if (auto vecType = dyn_cast<VectorType>(type)) {
79+
uint64_t baseSize = vecType.getNumElements() / vecType.getShape().back() *
80+
llvm::PowerOf2Ceil(vecType.getShape().back()) *
81+
dataLayout.getTypeSize(vecType.getElementType()) * 8;
82+
return llvm::TypeSize::get(baseSize, vecType.isScalable());
83+
}
8284

8385
if (auto typeInterface = dyn_cast<DataLayoutTypeInterface>(type))
8486
return typeInterface.getTypeSizeInBits(dataLayout, params);
@@ -138,9 +140,10 @@ getFloatTypeABIAlignment(FloatType fltType, const DataLayout &dataLayout,
138140
uint64_t mlir::detail::getDefaultABIAlignment(
139141
Type type, const DataLayout &dataLayout,
140142
ArrayRef<DataLayoutEntryInterface> params) {
141-
// Natural alignment is the closest power-of-two number above.
143+
// Natural alignment is the closest power-of-two number above. For scalable
144+
// vectors, aligning them to the same as the base vector is sufficient.
142145
if (isa<VectorType>(type))
143-
return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type));
146+
return llvm::PowerOf2Ceil(dataLayout.getTypeSize(type).getKnownMinValue());
144147

145148
if (auto fltType = dyn_cast<FloatType>(type))
146149
return getFloatTypeABIAlignment(fltType, dataLayout, params);

mlir/test/Interfaces/DataLayoutInterfaces/query.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,18 @@ func.func @no_layout_builtin() {
3232
// CHECK: preferred = 8
3333
// CHECK: size = 8
3434
"test.data_layout_query"() : () -> index
35+
// CHECK: alignment = 16
36+
// CHECK: bitsize = 128
37+
// CHECK: index = 0
38+
// CHECK: preferred = 16
39+
// CHECK: size = 16
40+
"test.data_layout_query"() : () -> vector<4xi32>
41+
// CHECK: alignment = 16
42+
// CHECK: bitsize = {minimal_size = 128 : index, scalable}
43+
// CHECK: index = 0
44+
// CHECK: preferred = 16
45+
// CHECK: size = {minimal_size = 16 : index, scalable}
46+
"test.data_layout_query"() : () -> vector<[4]xi32>
3547
return
3648

3749
}

mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,22 @@ struct TestDataLayoutQuery
4646
Attribute programMemorySpace = layout.getProgramMemorySpace();
4747
Attribute globalMemorySpace = layout.getGlobalMemorySpace();
4848
uint64_t stackAlignment = layout.getStackAlignment();
49+
50+
auto convertTypeSizeToAttr = [&](llvm::TypeSize typeSize) -> Attribute {
51+
if (!typeSize.isScalable())
52+
return builder.getIndexAttr(typeSize);
53+
54+
return builder.getDictionaryAttr({
55+
builder.getNamedAttr("scalable", builder.getUnitAttr()),
56+
builder.getNamedAttr(
57+
"minimal_size",
58+
builder.getIndexAttr(typeSize.getKnownMinValue())),
59+
});
60+
};
61+
4962
op->setAttrs(
50-
{builder.getNamedAttr("size", builder.getIndexAttr(size)),
51-
builder.getNamedAttr("bitsize", builder.getIndexAttr(bitsize)),
63+
{builder.getNamedAttr("size", convertTypeSizeToAttr(size)),
64+
builder.getNamedAttr("bitsize", convertTypeSizeToAttr(bitsize)),
5265
builder.getNamedAttr("alignment", builder.getIndexAttr(alignment)),
5366
builder.getNamedAttr("preferred", builder.getIndexAttr(preferred)),
5467
builder.getNamedAttr("index", builder.getIndexAttr(index)),

0 commit comments

Comments
 (0)