@@ -888,19 +888,6 @@ LogicalResult SparseTensorEncodingAttr::verify(
888
888
return success ();
889
889
}
890
890
891
- Type SparseTensorEncodingAttr::getMismatchedValueType (Type elementType,
892
- Attribute val) const {
893
- Type type;
894
- auto fVal = llvm::dyn_cast<FloatAttr>(val);
895
- auto intVal = llvm::dyn_cast<IntegerAttr>(val);
896
- if (fVal && fVal .getType () != elementType) {
897
- type = fVal .getType ();
898
- } else if (intVal && intVal.getType () != elementType) {
899
- type = intVal.getType ();
900
- }
901
- return type;
902
- }
903
-
904
891
LogicalResult SparseTensorEncodingAttr::verifyEncoding (
905
892
ArrayRef<Size> dimShape, Type elementType,
906
893
function_ref<InFlightDiagnostic()> emitError) const {
@@ -920,20 +907,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
920
907
return emitError ()
921
908
<< " dimension-rank mismatch between encoding and tensor shape: "
922
909
<< getDimRank () << " != " << dimRank;
923
- Type type;
924
910
if (getExplicitVal ()) {
925
- if ((type = getMismatchedValueType (elementType, getExplicitVal ()))) {
926
- return emitError () << " explicit value type mismatch between encoding and "
927
- << " tensor element type: " << type
928
- << " != " << elementType;
911
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal ())) {
912
+ Type attrType = typedAttr.getType ();
913
+ if (attrType != elementType) {
914
+ return emitError ()
915
+ << " explicit value type mismatch between encoding and "
916
+ << " tensor element type: " << attrType << " != " << elementType;
917
+ }
918
+ } else {
919
+ return emitError () << " expected typed explicit value" ;
929
920
}
930
921
}
931
922
if (getImplicitVal ()) {
932
923
auto impVal = getImplicitVal ();
933
- if ((type = getMismatchedValueType (elementType, impVal))) {
934
- return emitError () << " implicit value type mismatch between encoding and "
935
- << " tensor element type: " << type
936
- << " != " << elementType;
924
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal ())) {
925
+ Type attrType = typedAttr.getType ();
926
+ if (attrType != elementType) {
927
+ return emitError ()
928
+ << " implicit value type mismatch between encoding and "
929
+ << " tensor element type: " << attrType << " != " << elementType;
930
+ }
931
+ } else {
932
+ return emitError () << " expected typed implicit value" ;
937
933
}
938
934
// Currently, we only support zero as the implicit value.
939
935
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
0 commit comments