Skip to content

Commit 80816e7

Browse files
authored
[mlir][LLVM] handle ArrayAttr for constant array of structs (#139724)
While LLVM IR dialect has a way to represent arbitrary LLVM constant array of structs via an insert chain, it is in practice very expensive for the compilation time as soon as the array is bigger than a couple hundred elements. This is because generating and later folding such insert chain is really not cheap. This patch allows representing array of struct constants via ArrayAttr in the LLVM dialect.
1 parent 034eaed commit 80816e7

File tree

6 files changed

+179
-19
lines changed

6 files changed

+179
-19
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2073,9 +2073,9 @@ def LLVM_ConstantOp
20732073
Unlike LLVM IR, MLIR does not have first-class constant values. Therefore,
20742074
all constants must be created as SSA values before being used in other
20752075
operations. `llvm.mlir.constant` creates such values for scalars, vectors,
2076-
strings, and structs. It has a mandatory `value` attribute whose type
2077-
depends on the type of the constant value. The type of the constant value
2078-
must correspond to the attribute type converted to LLVM IR type.
2076+
strings, structs, and array of structs. It has a mandatory `value` attribute
2077+
whose type depends on the type of the constant value. The type of the constant
2078+
value must correspond to the attribute type converted to LLVM IR type.
20792079

20802080
When creating constant scalars, the `value` attribute must be either an
20812081
integer attribute or a floating point attribute. The type of the attribute
@@ -2097,6 +2097,11 @@ def LLVM_ConstantOp
20972097
must correspond to the type of the corresponding attribute element converted
20982098
to LLVM IR.
20992099

2100+
When creating an array of structs, the `value` attribute must be an array
2101+
attribute, itself containing zero, or undef, or array attributes for each
2102+
potential nested array type, and the elements of the leaf array attributes
2103+
for must match the struct element types or be zero or undef attributes.
2104+
21002105
Examples:
21012106

21022107
```mlir

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 83 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3142,6 +3142,74 @@ static bool hasScalableVectorType(Type t) {
31423142
return false;
31433143
}
31443144

3145+
/// Verifies the constant array represented by `arrayAttr` matches the provided
3146+
/// `arrayType`.
3147+
static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
3148+
LLVM::LLVMArrayType arrayType,
3149+
ArrayAttr arrayAttr, int dim) {
3150+
if (arrayType.getNumElements() != arrayAttr.size())
3151+
return op.emitOpError()
3152+
<< "array attribute size does not match array type size in "
3153+
"dimension "
3154+
<< dim << ": " << arrayAttr.size() << " vs. "
3155+
<< arrayType.getNumElements();
3156+
3157+
llvm::DenseSet<Attribute> elementsVerified;
3158+
3159+
// Recursively verify sub-dimensions for multidimensional arrays.
3160+
if (auto subArrayType =
3161+
dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType())) {
3162+
for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr))
3163+
if (elementsVerified.insert(elementAttr).second) {
3164+
if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3165+
continue;
3166+
auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3167+
if (!subArrayAttr)
3168+
return op.emitOpError()
3169+
<< "nested attribute for sub-array in dimension " << dim
3170+
<< " at index " << idx
3171+
<< " must be a zero, or undef, or array attribute";
3172+
if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr,
3173+
dim + 1)))
3174+
return failure();
3175+
}
3176+
return success();
3177+
}
3178+
3179+
// Forbid usages of ArrayAttr for simple array types that should use
3180+
// DenseElementsAttr instead. Note that there would be a use case for such
3181+
// array types when one element value is obtained via a ptr-to-int conversion
3182+
// from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
3183+
// user needs this so far, and it seems better to avoid people misusing the
3184+
// ArrayAttr for simple types.
3185+
auto structType = dyn_cast<LLVM::LLVMStructType>(arrayType.getElementType());
3186+
if (!structType)
3187+
return op.emitOpError() << "for array with an array attribute must have a "
3188+
"struct element type";
3189+
3190+
// Shallow verification that leaf attributes are appropriate as struct initial
3191+
// value.
3192+
size_t numStructElements = structType.getBody().size();
3193+
for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
3194+
if (elementsVerified.insert(elementAttr).second) {
3195+
if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3196+
continue;
3197+
auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3198+
if (!subArrayAttr)
3199+
return op.emitOpError()
3200+
<< "nested attribute for struct element at index " << idx
3201+
<< " must be a zero, or undef, or array attribute";
3202+
if (subArrayAttr.size() != numStructElements)
3203+
return op.emitOpError()
3204+
<< "nested array attribute size for struct element at index "
3205+
<< idx << " must match struct size: " << subArrayAttr.size()
3206+
<< " vs. " << numStructElements;
3207+
}
3208+
}
3209+
3210+
return success();
3211+
}
3212+
31453213
LogicalResult LLVM::ConstantOp::verify() {
31463214
if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
31473215
auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
@@ -3208,7 +3276,7 @@ LogicalResult LLVM::ConstantOp::verify() {
32083276
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
32093277
return emitOpError() << "expected integer type of width " << floatWidth;
32103278
}
3211-
} else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
3279+
} else if (isa<ElementsAttr>(getValue())) {
32123280
if (hasScalableVectorType(getType())) {
32133281
// The exact number of elements of a scalable vector is unknown, so we
32143282
// allow only splat attributes.
@@ -3221,15 +3289,20 @@ LogicalResult LLVM::ConstantOp::verify() {
32213289
if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
32223290
return emitOpError() << "expected vector or array type";
32233291
// The number of elements of the attribute and the type must match.
3224-
int64_t attrNumElements;
3225-
if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
3226-
attrNumElements = elementsAttr.getNumElements();
3227-
else
3228-
attrNumElements = cast<ArrayAttr>(getValue()).size();
3229-
if (getNumElements(getType()) != attrNumElements)
3230-
return emitOpError()
3231-
<< "type and attribute have a different number of elements: "
3232-
<< getNumElements(getType()) << " vs. " << attrNumElements;
3292+
if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
3293+
int64_t attrNumElements = elementsAttr.getNumElements();
3294+
if (getNumElements(getType()) != attrNumElements)
3295+
return emitOpError()
3296+
<< "type and attribute have a different number of elements: "
3297+
<< getNumElements(getType()) << " vs. " << attrNumElements;
3298+
}
3299+
} else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
3300+
auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
3301+
if (!arrayType)
3302+
return emitOpError() << "expected array type";
3303+
// When the attribute is an ArrayAttr, check that its nesting matches the
3304+
// corresponding ArrayType or VectorType nesting.
3305+
return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
32333306
} else {
32343307
return emitOpError()
32353308
<< "only supports integer, float, string or elements attributes";

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,8 +553,10 @@ static llvm::Constant *convertDenseResourceElementsAttr(
553553
llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
554554
llvm::Type *llvmType, Attribute attr, Location loc,
555555
const ModuleTranslation &moduleTranslation) {
556-
if (!attr)
556+
if (!attr || isa<UndefAttr>(attr))
557557
return llvm::UndefValue::get(llvmType);
558+
if (isa<ZeroAttr>(attr))
559+
return llvm::Constant::getNullValue(llvmType);
558560
if (auto *structType = dyn_cast<::llvm::StructType>(llvmType)) {
559561
auto arrayAttr = dyn_cast<ArrayAttr>(attr);
560562
if (!arrayAttr) {
@@ -713,6 +715,33 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
713715
ArrayRef<char>{stringAttr.getValue().data(),
714716
stringAttr.getValue().size()});
715717
}
718+
719+
// Handle arrays of structs that cannot be represented as DenseElementsAttr
720+
// in MLIR.
721+
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
722+
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
723+
llvm::Type *elementType = arrayTy->getElementType();
724+
Attribute previousElementAttr;
725+
llvm::Constant *elementCst = nullptr;
726+
SmallVector<llvm::Constant *> constants;
727+
constants.reserve(arrayTy->getNumElements());
728+
for (Attribute elementAttr : arrayAttr) {
729+
// Arrays with a single value or with repeating values are quite common.
730+
// Short-circuit the translation when the element value is the same as
731+
// the previous one.
732+
if (!previousElementAttr || previousElementAttr != elementAttr) {
733+
previousElementAttr = elementAttr;
734+
elementCst =
735+
getLLVMConstant(elementType, elementAttr, loc, moduleTranslation);
736+
if (!elementCst)
737+
return nullptr;
738+
}
739+
constants.push_back(elementCst);
740+
}
741+
return llvm::ConstantArray::get(arrayTy, constants);
742+
}
743+
}
744+
716745
emitError(loc, "unsupported constant value");
717746
return nullptr;
718747
}

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,3 +1850,35 @@ llvm.func @gep_inbounds_flag_usage(%ptr: !llvm.ptr, %idx: i64) {
18501850
llvm.getelementptr inbounds_flag %ptr[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
18511851
llvm.return
18521852
}
1853+
1854+
// -----
1855+
1856+
llvm.mlir.global @bad_struct_array_init_size() : !llvm.array<2x!llvm.struct<(i32, f32)>> {
1857+
// expected-error@below {{'llvm.mlir.constant' op array attribute size does not match array type size in dimension 0: 1 vs. 2}}
1858+
%0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2x!llvm.struct<(i32, f32)>>
1859+
llvm.return %0 : !llvm.array<2x!llvm.struct<(i32, f32)>>
1860+
}
1861+
1862+
// -----
1863+
1864+
llvm.mlir.global @bad_struct_array_init_nesting() : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>> {
1865+
// expected-error@below {{'llvm.mlir.constant' op nested attribute for sub-array in dimension 1 at index 0 must be a zero, or undef, or array attribute}}
1866+
%0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
1867+
llvm.return %0 : !llvm.array<1x!llvm.array<1x!llvm.array<1x!llvm.struct<(i32)>>>>
1868+
}
1869+
1870+
// -----
1871+
1872+
llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct<(i32, f32)>> {
1873+
// expected-error@below {{'llvm.mlir.constant' op nested array attribute size for struct element at index 0 must match struct size: 1 vs. 2}}
1874+
%0 = llvm.mlir.constant([[1 : i32]]) : !llvm.array<1x!llvm.struct<(i32, f32)>>
1875+
llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>>
1876+
}
1877+
1878+
// ----
1879+
1880+
llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 x f64> {
1881+
// expected-error@below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}}
1882+
%0 = llvm.mlir.constant([2.5, 7.4]) : !llvm.array<2 x f64>
1883+
llvm.return %0 : !llvm.array<2 x f64>
1884+
}

mlir/test/Target/LLVMIR/llvmir-invalid.mlir

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,6 @@ llvm.func @incompatible_integer_type_for_float_attr() -> i32 {
7979

8080
// -----
8181

82-
// expected-error @below{{unsupported constant value}}
83-
llvm.mlir.global internal constant @test([2.5, 7.4]) : !llvm.array<2 x f64>
84-
85-
// -----
86-
8782
// expected-error @below{{LLVM attribute 'readonly' does not expect a value}}
8883
llvm.func @passthrough_unexpected_value() attributes {passthrough = [["readonly", "42"]]}
8984

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3022,3 +3022,29 @@ llvm.func internal @i(%arg0: i32) attributes {dso_local} {
30223022
llvm.call @testfn3(%arg0) : (i32 {llvm.alignstack = 8 : i64}) -> ()
30233023
llvm.return
30243024
}
3025+
3026+
// -----
3027+
3028+
// CHECK: @test_array_attr_2 = global [2 x { i32, float }] [{ i32, float } { i32 42, float 1.000000e+00 }, { i32, float } { i32 42, float 1.000000e+00 }]
3029+
llvm.mlir.global @test_array_attr_2() : !llvm.array<2 x !llvm.struct<(i32, f32)>> {
3030+
%0 = llvm.mlir.constant([[42 : i32, 1.000000e+00 : f32],[42 : i32, 1.000000e+00 : f32]]) : !llvm.array<2 x !llvm.struct<(i32, f32)>>
3031+
llvm.return %0 : !llvm.array<2 x !llvm.struct<(i32, f32)>>
3032+
}
3033+
3034+
// CHECK: @test_array_attr_3 = global [2 x [3 x { i32, float }]{{.*}}[3 x { i32, float }] [{ i32, float } { i32 1, float 1.000000e+00 }, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } { i32 3, float 1.000000e+00 }], [3 x { i32, float }] [{ i32, float } { i32 4, float 1.000000e+00 }, { i32, float } { i32 5, float 1.000000e+00 }, { i32, float } { i32 6, float 1.000000e+00 }
3035+
llvm.mlir.global @test_array_attr_3() : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>> {
3036+
%0 = llvm.mlir.constant([[[1 : i32, 1.000000e+00 : f32], [2 : i32, 1.000000e+00 : f32], [3 : i32, 1.000000e+00 : f32]], [[4 : i32, 1.000000e+00 : f32], [5 : i32, 1.000000e+00 : f32], [6 : i32, 1.000000e+00 : f32]]]) : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>>
3037+
llvm.return %0 : !llvm.array<2 x !llvm.array<3 x !llvm.struct<(i32, f32)>>>
3038+
}
3039+
3040+
// CHECK: @test_array_attr_struct_with_ptr = internal constant [2 x { ptr }] [{ ptr } zeroinitializer, { ptr } undef]
3041+
llvm.mlir.global internal constant @test_array_attr_struct_with_ptr() : !llvm.array<2 x struct<(ptr)>> {
3042+
%0 = llvm.mlir.constant([[#llvm.zero], [#llvm.undef]]) : !llvm.array<2 x struct<(ptr)>>
3043+
llvm.return %0 : !llvm.array<2 x struct<(ptr)>>
3044+
}
3045+
3046+
// CHECK: @test_array_attr_struct_with_struct = internal constant [3 x { i32, float }] [{ i32, float } zeroinitializer, { i32, float } { i32 2, float 1.000000e+00 }, { i32, float } undef]
3047+
llvm.mlir.global internal constant @test_array_attr_struct_with_struct() : !llvm.array<3 x struct<(i32, f32)>> {
3048+
%0 = llvm.mlir.constant([#llvm.zero, [2 : i32, 1.0 : f32], #llvm.undef]) : !llvm.array<3 x struct<(i32, f32)>>
3049+
llvm.return %0 : !llvm.array<3 x struct<(i32, f32)>>
3050+
}

0 commit comments

Comments
 (0)