@@ -2666,6 +2666,20 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
2666
2666
// ConstantOp.
2667
2667
// ===----------------------------------------------------------------------===//
2668
2668
2669
+ // / Compute the total number of elements in the given type, also taking into
2670
+ // / account nested types. Supported types are `VectorType`, `LLVMArrayType` and
2671
+ // / `LLVMFixedVectorType`. Everything else is treated as a scalar.
2672
+ static int64_t getNumElements (Type t) {
2673
+ if (auto vecType = dyn_cast<VectorType>(t))
2674
+ return vecType.getNumElements () * getNumElements (vecType.getElementType ());
2675
+ if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
2676
+ return arrayType.getNumElements () *
2677
+ getNumElements (arrayType.getElementType ());
2678
+ if (auto vecType = dyn_cast<LLVMFixedVectorType>(t))
2679
+ return vecType.getNumElements () * getNumElements (vecType.getElementType ());
2680
+ return 1 ;
2681
+ }
2682
+
2669
2683
LogicalResult LLVM::ConstantOp::verify () {
2670
2684
if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue ())) {
2671
2685
auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType ());
@@ -2708,14 +2722,12 @@ LogicalResult LLVM::ConstantOp::verify() {
2708
2722
if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType ())) {
2709
2723
return emitOpError () << " does not support target extension type." ;
2710
2724
}
2711
- if (!llvm::isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>(getValue ()))
2712
- return emitOpError ()
2713
- << " only supports integer, float, string or elements attributes" ;
2725
+
2726
+ // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
2714
2727
if (auto intAttr = dyn_cast<IntegerAttr>(getValue ())) {
2715
2728
if (!llvm::isa<IntegerType>(getType ()))
2716
2729
return emitOpError () << " expected integer type" ;
2717
- }
2718
- if (auto floatAttr = dyn_cast<FloatAttr>(getValue ())) {
2730
+ } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue ())) {
2719
2731
const llvm::fltSemantics &sem = floatAttr.getValue ().getSemantics ();
2720
2732
unsigned floatWidth = APFloat::getSizeInBits (sem);
2721
2733
if (auto floatTy = dyn_cast<FloatType>(getType ())) {
@@ -2728,13 +2740,30 @@ LogicalResult LLVM::ConstantOp::verify() {
2728
2740
if (isa<IntegerType>(getType ()) && !getType ().isInteger (floatWidth)) {
2729
2741
return emitOpError () << " expected integer type of width " << floatWidth;
2730
2742
}
2731
- }
2732
- if (auto splatAttr = dyn_cast<SplatElementsAttr>(getValue ())) {
2733
- if (!isa<VectorType>(getType ()) && !isa<LLVM::LLVMArrayType>(getType ()) &&
2734
- !isa<LLVM::LLVMFixedVectorType>(getType ()) &&
2735
- !isa<LLVM::LLVMScalableVectorType>(getType ()))
2743
+ } else if (isa<ElementsAttr, ArrayAttr>(getValue ())) {
2744
+ if (isa<LLVM::LLVMScalableVectorType>(getType ())) {
2745
+ // The exact number of elements of a scalable vector is unknown, so there
2746
+ // is nothing more to verify.
2747
+ return success ();
2748
+ }
2749
+ if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>(
2750
+ getType ()))
2736
2751
return emitOpError () << " expected vector or array type" ;
2752
+ // The number of elements of the attribute and the type must match.
2753
+ int64_t attrNumElements;
2754
+ if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue ()))
2755
+ attrNumElements = elementsAttr.getNumElements ();
2756
+ else
2757
+ attrNumElements = cast<ArrayAttr>(getValue ()).size ();
2758
+ if (getNumElements (getType ()) != attrNumElements)
2759
+ return emitOpError ()
2760
+ << " type and attribute have a different number of elements: "
2761
+ << getNumElements (getType ()) << " vs. " << attrNumElements;
2762
+ } else {
2763
+ return emitOpError ()
2764
+ << " only supports integer, float, string or elements attributes" ;
2737
2765
}
2766
+
2738
2767
return success ();
2739
2768
}
2740
2769
0 commit comments