Skip to content

[mlir][LLVMIR] Check number of elements in mlir.constant verifier #102906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1623,10 +1623,15 @@ def LLVM_ConstantOp
vectors. It has a mandatory `value` attribute, which may be an integer,
floating point attribute; dense or sparse attribute containing integers or
floats. The type of the attribute is one of the corresponding MLIR builtin
types. It may be omitted for `i64` and `f64` types that are implied. The
operation produces a new SSA value of the specified LLVM IR dialect type.
The type of that value _must_ correspond to the attribute type converted to
LLVM IR.
types. It may be omitted for `i64` and `f64` types that are implied.

The operation produces a new SSA value of the specified LLVM IR dialect
type. Certain builtin types such as integer, float and vector types are
also allowed. The result type _must_ correspond to the attribute type
converted to LLVM IR. In particular, the number of elements of a container
type must match the number of elements in the attribute. If the type is or
contains a scalable vector type, the attribute must be a splat elements
attribute.

Examples:

Expand Down
72 changes: 62 additions & 10 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2666,6 +2666,39 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
// ConstantOp.
//===----------------------------------------------------------------------===//

/// Compute the total number of elements in the given type, also taking into
/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and
/// `LLVMFixedVectorType`. Everything else is treated as a scalar.
static int64_t getNumElements(Type t) {
if (auto vecType = dyn_cast<VectorType>(t))
return vecType.getNumElements() * getNumElements(vecType.getElementType());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
return arrayType.getNumElements() *
getNumElements(arrayType.getElementType());
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
return vecType.getNumElements() * getNumElements(vecType.getElementType());
assert(!isa<LLVM::LLVMScalableVectorType>(t) &&
"number of elements of a scalable vector type is unknown");
return 1;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if I should take into account scalable vector types here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getVectorNumElements maybe helpful (

llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not take into account nested vectors unfortunately.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right since the verifier needs to look not only at vector types it probably makes sense to have the specialized helpers you added.

}

/// Check if the given type is a scalable vector type or a vector/array type
/// that contains a nested scalable vector type.
static bool hasScalableVectorType(Type t) {
if (isa<LLVM::LLVMScalableVectorType>(t))
return true;
if (auto vecType = dyn_cast<VectorType>(t)) {
if (vecType.isScalable())
return true;
return hasScalableVectorType(vecType.getElementType());
}
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
return hasScalableVectorType(arrayType.getElementType());
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
return hasScalableVectorType(vecType.getElementType());
return false;
}

LogicalResult LLVM::ConstantOp::verify() {
if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
Expand Down Expand Up @@ -2708,14 +2741,12 @@ LogicalResult LLVM::ConstantOp::verify() {
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
return emitOpError() << "does not support target extension type.";
}
if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
return emitOpError()
<< "only supports integer, float, string or elements attributes";

// Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
if (!llvm::isa<IntegerType>(getType()))
return emitOpError() << "expected integer type";
}
if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
} else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
unsigned floatWidth = APFloat::getSizeInBits(sem);
if (auto floatTy = dyn_cast<FloatType>(getType())) {
Expand All @@ -2728,13 +2759,34 @@ LogicalResult LLVM::ConstantOp::verify() {
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
return emitOpError() << "expected integer type of width " << floatWidth;
}
}
if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
!isa<LLVM::LLVMFixedVectorType>(getType()) &&
!isa<LLVM::LLVMScalableVectorType>(getType()))
} else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
if (hasScalableVectorType(getType())) {
// The exact number of elements of a scalable vector is unknown, so we
// allow only splat attributes.
auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
if (!splatElementsAttr)
return emitOpError()
<< "scalable vector type requires a splat attribute";
return success();
}
if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
getType()))
return emitOpError() << "expected vector or array type";
// The number of elements of the attribute and the type must match.
int64_t attrNumElements;
if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
attrNumElements = elementsAttr.getNumElements();
else
attrNumElements = cast<ArrayAttr>(getValue()).size();
if (getNumElements(getType()) != attrNumElements)
return emitOpError()
<< "type and attribute have a different number of elements: "
<< getNumElements(getType()) << " vs. " << attrNumElements;
} else {
return emitOpError()
<< "only supports integer, float, string or elements attributes";
}

return success();
}

Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,22 @@ llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !

// -----

llvm.func @const_wrong_number_of_elements() -> vector<5xf64> {
// expected-error @+1{{type and attribute have a different number of elements: 5 vs. 2}}
%0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<5xf64>
llvm.return %0 : vector<5xf64>
}

// -----

llvm.func @scalable_vec_requires_splat() -> vector<[4]xf64> {
// expected-error @+1{{scalable vector type requires a splat attribute}}
%0 = llvm.mlir.constant(dense<[1.0, 1.0, 2.0, 2.0]> : tensor<4xf64>) : vector<[4]xf64>
llvm.return %0 : vector<[4]xf64>
}

// -----

func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
// expected-error@+2 {{expected LLVM IR Dialect type}}
llvm.insertvalue %a, %b[0] : tensor<*xi32>
Expand Down
8 changes: 7 additions & 1 deletion mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1295,11 +1295,17 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
}

llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
// CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
}

llvm.func @complexintconstantsingle() -> !llvm.array<1 x !llvm.struct<(i32, i32)>> {
%1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<1 x !llvm.struct<(i32, i32)>>
// CHECK: ret [1 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }]
llvm.return %1 : !llvm.array<1 x !llvm.struct<(i32, i32)>>
}

llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
Expand Down
Loading