Skip to content

[mlir][llvm] Fix verifier for const int and dense #74340

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 5, 2023
Merged

[mlir][llvm] Fix verifier for const int and dense #74340

merged 5 commits into from
Dec 5, 2023

Conversation

rikhuijzer
Copy link
Member

Continuation of #74247 to fix #56962. Fixes verifier for (Integer Attr):

llvm.mlir.constant(1 : index) : f32

and (Dense Attr):

llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32

Integer Attr

The addition that this PR makes to LLVM::ConstantOp::verify is meant to be exactly verifying the code in mlir::LLVM::detail::getLLVMConstant:

if (auto intAttr = dyn_cast<IntegerAttr>(attr))
return llvm::ConstantInt::get(
llvmType,
intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));

One failure mode is when the type (llvm.mlir.constant(<value>) : <type>) is not an Integer, because then the cast in getIntegerBitWidth will crash:

unsigned Type::getIntegerBitWidth() const {
return cast<IntegerType>(this)->getBitWidth();
}

So that's now caught in the verifier.

Apart from that, I don't see anything we could check for. sextOrTrunc means "Sign extend or truncate to width" and that one is quite permissive. For example, the following doesn't have to be caught in the verifier as it doesn't crash during mlir-translate -mlir-to-llvmir:

llvm.func @main() -> f32 {
  %cst = llvm.mlir.constant(100 : i64) : f32
  llvm.return %cst : f32
}

Dense Attr

Crash if not either a MLIR Vector type or one of these:

if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
llvm::Type *elementType;
uint64_t numElements;
bool isScalable = false;
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
elementType = arrayTy->getElementType();
numElements = arrayTy->getNumElements();
} else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
elementType = fVectorTy->getElementType();
numElements = fVectorTy->getNumElements();
} else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
elementType = sVectorTy->getElementType();
numElements = sVectorTy->getMinNumElements();
isScalable = true;
} else {
llvm_unreachable("unrecognized constant vector type");
}

@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Rik Huijzer (rikhuijzer)

Changes

Continuation of #74247 to fix #56962. Fixes verifier for (Integer Attr):

llvm.mlir.constant(1 : index) : f32

and (Dense Attr):

llvm.mlir.constant(dense&lt;100.0&gt; : vector&lt;1xf64&gt;) : f32

Integer Attr

The addition that this PR makes to LLVM::ConstantOp::verify is meant to be exactly verifying the code in mlir::LLVM::detail::getLLVMConstant:

if (auto intAttr = dyn_cast<IntegerAttr>(attr))
return llvm::ConstantInt::get(
llvmType,
intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));

One failure mode is when the type (llvm.mlir.constant(&lt;value&gt;) : &lt;type&gt;) is not an Integer, because then the cast in getIntegerBitWidth will crash:

unsigned Type::getIntegerBitWidth() const {
return cast<IntegerType>(this)->getBitWidth();
}

So that's now caught in the verifier.

Apart from that, I don't see anything we could check for. sextOrTrunc means "Sign extend or truncate to width" and that one is quite permissive. For example, the following doesn't have to be caught in the verifier as it doesn't crash during mlir-translate -mlir-to-llvmir:

llvm.func @<!-- -->main() -&gt; f32 {
  %cst = llvm.mlir.constant(100 : i64) : f32
  llvm.return %cst : f32
}

Dense Attr

Crash if not either a MLIR Vector type or one of these:

if (auto splatAttr = dyn_cast<SplatElementsAttr>(attr)) {
llvm::Type *elementType;
uint64_t numElements;
bool isScalable = false;
if (auto *arrayTy = dyn_cast<llvm::ArrayType>(llvmType)) {
elementType = arrayTy->getElementType();
numElements = arrayTy->getNumElements();
} else if (auto *fVectorTy = dyn_cast<llvm::FixedVectorType>(llvmType)) {
elementType = fVectorTy->getElementType();
numElements = fVectorTy->getNumElements();
} else if (auto *sVectorTy = dyn_cast<llvm::ScalableVectorType>(llvmType)) {
elementType = sVectorTy->getElementType();
numElements = sVectorTy->getMinNumElements();
isScalable = true;
} else {
llvm_unreachable("unrecognized constant vector type");
}


Full diff: https://github.com/llvm/llvm-project/pull/74340.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+9)
  • (modified) mlir/test/Target/LLVMIR/llvmir-invalid.mlir (+16)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 71860363b8ea5..705788d237d7f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -2539,6 +2539,10 @@ LogicalResult LLVM::ConstantOp::verify() {
   if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue()))
     return emitOpError()
            << "only supports integer, float, string or elements attributes";
+  if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
+    if (!llvm::isa<IntegerType>(getType()))
+      return emitOpError() << "expected integer type";
+  }
   if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
     const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
     unsigned floatWidth = APFloat::getSizeInBits(sem);
@@ -2553,6 +2557,11 @@ LogicalResult LLVM::ConstantOp::verify() {
       return emitOpError() << "expected integer type of width " << floatWidth;
     }
   }
+  if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue())) {
+    if (!getType().isa<VectorType>() && !getType().isa<LLVM::LLVMArrayType>() &&
+        !getType().isa<LLVM::LLVMFixedVectorType>() && !getType().isa<LLVM::LLVMScalableVectorType>())
+      return emitOpError() << "expected vector or array type";
+  }
   return success();
 }
 
diff --git a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
index 0def5895fb330..117a6b8269089 100644
--- a/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/llvmir-invalid.mlir
@@ -7,6 +7,14 @@ func.func @foo() {
 
 // -----
 
+llvm.func @vector_with_non_vector_type() -> f32 {
+  // expected-error @below{{expected vector or array type}}
+  %cst = llvm.mlir.constant(dense<100.0> : vector<1xf64>) : f32
+  llvm.return %cst : f32
+}
+
+// -----
+
 llvm.func @no_non_complex_struct() -> !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>> {
   // expected-error @below{{expected struct type to be a complex number}}
   %0 = llvm.mlir.constant(dense<[[[1, 2], [3, 4]], [[42, 43], [44, 45]]]> : tensor<2x2x2xi32>) : !llvm.array<2 x array<2 x array<2 x struct<(i32)>>>>
@@ -31,6 +39,14 @@ llvm.func @struct_wrong_attribute_element_type() -> !llvm.struct<(f64, f64)> {
 
 // -----
 
+llvm.func @integer_with_float_type() -> f32 {
+  // expected-error @+1 {{expected integer type}}
+  %0 = llvm.mlir.constant(1 : index) : f32
+  llvm.return %0 : f32
+}
+
+// -----
+
 llvm.func @incompatible_float_attribute_type() -> f32 {
   // expected-error @below{{expected float type of width 64}}
   %cst = llvm.mlir.constant(1.0 : f64) : f32

Copy link

github-actions bot commented Dec 4, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@rikhuijzer rikhuijzer merged commit 13da9a5 into llvm:main Dec 5, 2023
@rikhuijzer rikhuijzer deleted the rh/mismatched-llvm-int-const branch December 5, 2023 11:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir:llvm] LLVM::ConstantOp allows mismatched constant value
3 participants