Skip to content

[mlir][LLVMIR] Check number of elements in mlir.constant verifier #102906

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

Check that the number of elements in the result type and the attribute of an llvm.mlir.constant op matches. Also fix a broken test where that was not the case.

@llvmbot
Copy link
Member

llvmbot commented Aug 12, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

Check that the number of elements in the result type and the attribute of an llvm.mlir.constant op matches. Also fix a broken test where that was not the case.


Full diff: https://github.com/llvm/llvm-project/pull/102906.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+39-10)
  • (modified) mlir/test/Dialect/LLVMIR/invalid.mlir (+8)
  • (modified) mlir/test/Target/LLVMIR/llvmir.mlir (+7-1)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 90610118a45cd2..82dadfc98de7b6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2666,6 +2666,20 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
 // ConstantOp.
 //===----------------------------------------------------------------------===//
 
+/// Compute the total number of elements in the given type, also taking into
+/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and
+/// `LLVMFixedVectorType`. Everything else is treated as a scalar.
+static int64_t getNumElements(Type t) {
+  if (auto vecType = dyn_cast<VectorType>(t))
+    return vecType.getNumElements() * getNumElements(vecType.getElementType());
+  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
+    return arrayType.getNumElements() *
+           getNumElements(arrayType.getElementType());
+  if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
+    return vecType.getNumElements() * getNumElements(vecType.getElementType());
+  return 1;
+}
+
 LogicalResult LLVM::ConstantOp::verify() {
   if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
     auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
@@ -2708,14 +2722,12 @@ LogicalResult LLVM::ConstantOp::verify() {
   if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
     return emitOpError() << "does not support target extension type.";
   }
-  if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
-    return emitOpError()
-           << "only supports integer, float, string or elements attributes";
+
+  // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
   if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
     if (!llvm::isa<IntegerType>(getType()))
       return emitOpError() << "expected integer type";
-  }
-  if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
+  } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
     const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
     unsigned floatWidth = APFloat::getSizeInBits(sem);
     if (auto floatTy = dyn_cast<FloatType>(getType())) {
@@ -2728,13 +2740,30 @@ LogicalResult LLVM::ConstantOp::verify() {
     if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
       return emitOpError() << "expected integer type of width " << floatWidth;
     }
-  }
-  if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
-    if (!isa<VectorType>(getType()) && !isa<LLVM::LLVMArrayType>(getType()) &&
-        !isa<LLVM::LLVMFixedVectorType>(getType()) &&
-        !isa<LLVM::LLVMScalableVectorType>(getType()))
+  } else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
+    if (isa<LLVM::LLVMScalableVectorType>(getType())) {
+      // The exact number of elements of a scalable vector is unknown, so there
+      // is nothing more to verify.
+      return success();
+    }
+    if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
+            getType()))
       return emitOpError() << "expected vector or array type";
+    // The number of elements of the attribute and the type must match.
+    int64_t attrNumElements;
+    if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue()))
+      attrNumElements = elementsAttr.getNumElements();
+    else
+      attrNumElements = cast<ArrayAttr>(getValue()).size();
+    if (getNumElements(getType()) != attrNumElements)
+      return emitOpError()
+             << "type and attribute have a different number of elements: "
+             << getNumElements(getType()) << " vs. " << attrNumElements;
+  } else {
+    return emitOpError()
+           << "only supports integer, float, string or elements attributes";
   }
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir
index fe288dab973f5a..7edf036201e1c0 100644
--- a/mlir/test/Dialect/LLVMIR/invalid.mlir
+++ b/mlir/test/Dialect/LLVMIR/invalid.mlir
@@ -414,6 +414,14 @@ llvm.func @struct_wrong_element_types() -> !llvm.struct<(!llvm.array<2 x f64>, !
 
 // -----
 
+llvm.func @struct_wrong_element_types() -> vector<5xf64> {
+  // expected-error @+1{{type and attribute have a different number of elements: 5 vs. 2}}
+  %0 = llvm.mlir.constant(dense<[1.0, 1.0]> : tensor<2xf64>) : vector<5xf64>
+  llvm.return %0 : vector<5xf64>
+}
+
+// -----
+
 func.func @insertvalue_non_llvm_type(%a : i32, %b : i32) {
   // expected-error@+2 {{expected LLVM IR Dialect type}}
   llvm.insertvalue %a, %b[0] : tensor<*xi32>
diff --git a/mlir/test/Target/LLVMIR/llvmir.mlir b/mlir/test/Target/LLVMIR/llvmir.mlir
index fbdf725f3ec17b..8453983aa07c33 100644
--- a/mlir/test/Target/LLVMIR/llvmir.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir.mlir
@@ -1295,11 +1295,17 @@ llvm.func @complexintconstant() -> !llvm.struct<(i32, i32)> {
 }
 
 llvm.func @complexintconstantsplat() -> !llvm.array<2 x !llvm.struct<(i32, i32)>> {
-  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
+  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<2xcomplex<i32>>) : !llvm.array<2 x !llvm.struct<(i32, i32)>>
   // CHECK: ret [2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 0, i32 1 }]
   llvm.return %1 : !llvm.array<2 x !llvm.struct<(i32, i32)>>
 }
 
+llvm.func @complexintconstantsingle() -> !llvm.array<1 x !llvm.struct<(i32, i32)>> {
+  %1 = llvm.mlir.constant(dense<(0, 1)> : tensor<complex<i32>>) : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+  // CHECK: ret [1 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }]
+  llvm.return %1 : !llvm.array<1 x !llvm.struct<(i32, i32)>>
+}
+
 llvm.func @complexintconstantarray() -> !llvm.array<2 x !llvm.array<2 x !llvm.struct<(i32, i32)>>> {
   %1 = llvm.mlir.constant(dense<[[(0, 1), (2, 3)], [(4, 5), (6, 7)]]> : tensor<2x2xcomplex<i32>>) : !llvm.array<2 x!llvm.array<2 x !llvm.struct<(i32, i32)>>>
   // CHECK{LITERAL}: ret [2 x [2 x { i32, i32 }]] [[2 x { i32, i32 }] [{ i32, i32 } { i32 0, i32 1 }, { i32, i32 } { i32 2, i32 3 }], [2 x { i32, i32 }] [{ i32, i32 } { i32 4, i32 5 }, { i32, i32 } { i32 6, i32 7 }]]

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for the fix.

!isa<LLVM::LLVMFixedVectorType>(getType()) &&
!isa<LLVM::LLVMScalableVectorType>(getType()))
} else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
if (isa<LLVM::LLVMScalableVectorType>(getType())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it even allowed to specify constant scalable vectors in the first place?

Copy link
Member Author

@matthias-springer matthias-springer Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also wondering about that. We have test cases such as:

mlir/test/Conversion/MathToLLVM/math-to-llvm.mlir:  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(dense<1.000000e+00> : vector<[4]xf32>) : vector<[4]xf32>
mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir:  // CHECK-NEXT:    %[[RES0:.*]] = llvm.mlir.constant(dense<false> : vector<2x[16]xi1>) : !llvm.array<2 x vector<[16]xi1>>

This is what the op documentation says: The operation produces a new SSA value of the specified LLVM IR dialect type. The type of that value _must_ correspond to the attribute type converted to LLVM IR.

I'm not sure what "correspond" means in this context.

Copy link
Contributor

@Dinistro Dinistro Aug 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that this is just UB when the vector size does not match. I usually check what LLVM's verifier catch for such cases, but scalable vector support might not be too great.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The arith constant seems to check if the attribute is a splat elements attribute. A size check is not needed in the arith case since result and value type are equivalent...

I would suggest to keep the PR as is for now except you are up to figure out the exact semantics here. I assume the right thing would be to check for a splat elements attribute and matching scalable sizes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@banach-space @dcaballe Do you know how we should verify the number of elements when we have vector types and/or attributes with scalable dimensions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've looked at this in a slightly different context recently. My recommendation is to only allow "splats" when dealing with scalable vectors. As in, this would be fine:

llvm.mlir.constant(dense<[0]> : vector<[4]xi32>)

but not this:

llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<[4]xi32>)

I don't envisage us requiring the 2nd option.

Copy link
Member Author

@matthias-springer matthias-springer Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added another check to the verifier such that a scalable vector type requires a splat attribute. Also updated the operation documentation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Copy link
Contributor

@gysit gysit left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

!isa<LLVM::LLVMFixedVectorType>(getType()) &&
!isa<LLVM::LLVMScalableVectorType>(getType()))
} else if (isa<ElementsAttr, ArrayAttr>(getValue())) {
if (isa<LLVM::LLVMScalableVectorType>(getType())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The arith constant seems to check if the attribute is a splat elements attribute. A size check is not needed in the arith case since result and value type are equivalent...

I would suggest to keep the PR as is for now except you are up to figure out the exact semantics here. I assume the right thing would be to check for a splat elements attribute and matching scalable sizes?

getNumElements(arrayType.getElementType());
if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
return vecType.getNumElements() * getNumElements(vecType.getElementType());
return 1;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if I should take into account scalable vector types here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getVectorNumElements maybe helpful (

llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not take into account nested vectors unfortunately.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right since the verifier needs to look not only at vector types it probably makes sense to have the specialized helpers you added.

Check that the number of elements in the result type and the attribute of an `llvm.mlir.constant` op matches. Also fix a broken test where that was not the case.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/llvm_constant_verifier branch from d8792fe to 8678ad8 Compare August 13, 2024 08:13
@matthias-springer matthias-springer merged commit b8bf14e into main Aug 13, 2024
6 of 7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/llvm_constant_verifier branch August 13, 2024 08:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants