-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][LLVMIR] Check number of elements in mlir.constant
verifier
#102906
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVMIR] Check number of elements in mlir.constant
verifier
#102906
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesCheck that the number of elements in the result type and the attribute of an Full diff: https://github.com/llvm/llvm-project/pull/102906.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 90610118a45cd2..82dadfc98de7b6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2666,6 +2666,20 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
// ConstantOp.
//===----------------------------------------------------------------------===//
+/// Compute the total number of elements in the given type, also taking into
+/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and
+/// `LLVMFixedVectorType`. Everything else is treated as a scalar.
+static int64_t getNumElements(Type t) {
+ if (auto vecType = dyn_cast<VectorType>(t))
+ return vecType.getNumElements() * getNumElements(vecType.getElementType());
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
+ return arrayType.getNumElements() *
+ getNumElements(arrayType.getElementType());
+ if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
+ return vecType.getNumElements() * getNumElements(vecType.getElementType());
+ return 1;
+}
+
LogicalResult LLVM::ConstantOp::verify() {
if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
@@ -2708,14 +2722,12 @@ LogicalResult LLVM::ConstantOp::verify() {
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
return emitOpError() << "does not support target extension type.";
}
- if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
- return emitOpError()
- << "only supports integer, float, string or elements attributes";
+
+ // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
if (!llvm::isa<IntegerType>(getType()))
return emitOpError() << "expected integer type";
- }
- if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
+ } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
unsigned floatWidth = APFloat::getSizeInBits(sem);
if (auto floatTy = dyn_cast<FloatType>(getType())) {
@@ -2728,13 +2740,30 @@ LogicalResult LLVM::ConstantOp::verify() {
if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
return emitOpError() << "expected integer type of width " << floatWidth;
}
- }
- if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
- if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
- !isa<LLVM::LLVMFixedVectorType>(getType()) &&
- !isa<LLVM::LLVMScalableVectorType>(getType()))
+ } else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
+ if (isa<LLVM::LLVMScalableVectorType>(getType())) {
+ // The exact number of elements of a scalable vector is unknown, so there
+ // is nothing more to verify.
+ return success();
+ }
+ if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
+ getType()))
return emitOpError() << "expected vector or array type";
+ // The number of elements of the attribute and the type must match.
+ int64_t attrNumElements;
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
+ attrNumElements = elementsAttr.getNumElements();
+ else
+ attrNumElements = cast<ArrayAttr>(getValue()).size();
+ if (getNumElements(getType()) != attrNumElements)
+ return emitOpError()
+ << "type and attribute have a different number of elements: "
+ << getNumElements(getType()) << " vs. " << attrNumElements;
+ } else {
+ return emitOpError()
+ << "only supports integer, float, string or elements attributes";
}
+
return success();
}
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe288dab973f5a..7edf036201e1c0 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -414,6 +414,14 @@ llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !
// -----
+llvm.func @struct_wrong_element_types() -> vector<5xf64> {
+ // expected-error @+1{{type and attribute have a different number of elements: 5 vs. 2}}
+ %0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<5xf64>
+ llvm.return %0 : vector<5xf64>
+}
+
+// -----
+
func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
// expected-error@+2 {{expected LLVM IR Dialect type}}
llvm.insertvalue %a, %b[0] : tensor<*xi32>
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fbdf725f3ec17b..8453983aa07c33 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1295,11 +1295,17 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
}
llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
- %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+ %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
// CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
}
+llvm.func @complexintconstantsingle() -> !llvm.array<1 x !llvm.struct<(i32, i32)>> {
+ %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+ // CHECK: ret [1 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }]
+ llvm.return %1 : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+}
+
llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
%1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
// CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the fix.
!isa<LLVM::LLVMFixedVectorType>(getType()) && | ||
!isa<LLVM::LLVMScalableVectorType>(getType())) | ||
} else if (isa<ElementsAttr, ArrayAttr>(getValue())) { | ||
if (isa<LLVM::LLVMScalableVectorType>(getType())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it even allowed to specify constant scalable vectors in the first place?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was also wondering about that. We have test cases such as:
mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir: // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir: // CHECK-NEXT: %[[RES0:.*]] = llvm.mlir.constant(dense<false> : vector<2x[16]xi1>) : !llvm.array<2 x vector<[16]xi1>>
This is what the op documentation says: The operation produces a new SSA value of the specified LLVM IR dialect type. The type of that value _must_ correspond to the attribute type converted to LLVM IR.
I'm not sure what "correspond" means in this context.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess that this is just UB when the vector size does not match. I usually check what LLVM's verifier catch for such cases, but scalable vector support might not be too great.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The arith constant seems to check if the attribute is a splat elements attribute. A size check is not needed in the arith case since result and value type are equivalent...
I would suggest to keep the PR as is for now except you are up to figure out the exact semantics here. I assume the right thing would be to check for a splat elements attribute and matching scalable sizes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@banach-space @dcaballe Do you know how we should verify the number of elements when we have vector types and/or attributes with scalable dimensions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've looked at this in a slightly different context recently. My recommendation is to only allow "splats" when dealing with scalable vectors. As in, this would be fine:
llvm.mlir.constant(dense<[0]> : vector<[4]xi32>)
but not this:
llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<[4]xi32>)
I don't envisage us requiring the 2nd option.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added another check to the verifier such that a scalable vector type requires a splat attribute. Also updated the operation documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
!isa<LLVM::LLVMFixedVectorType>(getType()) && | ||
!isa<LLVM::LLVMScalableVectorType>(getType())) | ||
} else if (isa<ElementsAttr, ArrayAttr>(getValue())) { | ||
if (isa<LLVM::LLVMScalableVectorType>(getType())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The arith constant seems to check if the attribute is a splat elements attribute. A size check is not needed in the arith case since result and value type are equivalent...
I would suggest to keep the PR as is for now except you are up to figure out the exact semantics here. I assume the right thing would be to check for a splat elements attribute and matching scalable sizes?
getNumElements(arrayType.getElementType()); | ||
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t)) | ||
return vecType.getNumElements() * getNumElements(vecType.getElementType()); | ||
return 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if I should take into account scalable vector types here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getVectorNumElements
maybe helpful (
llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This does not take into account nested vectors unfortunately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right since the verifier needs to look not only at vector types it probably makes sense to have the specialized helpers you added.
Check that the number of elements in the result type and the attribute of an `llvm.mlir.constant` op matches. Also fix a broken test where that was not the case.
d8792fe
to
8678ad8
Compare
Check that the number of elements in the result type and the attribute of an
llvm.mlir.constant
op matches. Also fix a broken test where that was not the case.