@@ -893,19 +893,6 @@ LogicalResult SparseTensorEncodingAttr::verify(
893
893
return success ();
894
894
}
895
895
896
- Type SparseTensorEncodingAttr::getMismatchedValueType (Type elementType,
897
- Attribute val) const {
898
- Type type;
899
- auto fVal = llvm::dyn_cast<FloatAttr>(val);
900
- auto intVal = llvm::dyn_cast<IntegerAttr>(val);
901
- if (fVal && fVal .getType () != elementType) {
902
- type = fVal .getType ();
903
- } else if (intVal && intVal.getType () != elementType) {
904
- type = intVal.getType ();
905
- }
906
- return type;
907
- }
908
-
909
896
LogicalResult SparseTensorEncodingAttr::verifyEncoding (
910
897
ArrayRef<Size> dimShape, Type elementType,
911
898
function_ref<InFlightDiagnostic()> emitError) const {
@@ -925,20 +912,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
925
912
return emitError ()
926
913
<< " dimension-rank mismatch between encoding and tensor shape: "
927
914
<< getDimRank () << " != " << dimRank;
928
- Type type;
929
915
if (getExplicitVal ()) {
930
- if ((type = getMismatchedValueType (elementType, getExplicitVal ()))) {
931
- return emitError () << " explicit value type mismatch between encoding and "
932
- << " tensor element type: " << type
933
- << " != " << elementType;
916
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal ())) {
917
+ Type attrType = typedAttr.getType ();
918
+ if (attrType != elementType) {
919
+ return emitError ()
920
+ << " explicit value type mismatch between encoding and "
921
+ << " tensor element type: " << attrType << " != " << elementType;
922
+ }
923
+ } else {
924
+ return emitError () << " expected typed explicit value" ;
934
925
}
935
926
}
936
927
if (getImplicitVal ()) {
937
928
auto impVal = getImplicitVal ();
938
- if ((type = getMismatchedValueType (elementType, impVal))) {
939
- return emitError () << " implicit value type mismatch between encoding and "
940
- << " tensor element type: " << type
941
- << " != " << elementType;
929
+ if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal ())) {
930
+ Type attrType = typedAttr.getType ();
931
+ if (attrType != elementType) {
932
+ return emitError ()
933
+ << " implicit value type mismatch between encoding and "
934
+ << " tensor element type: " << attrType << " != " << elementType;
935
+ }
936
+ } else {
937
+ return emitError () << " expected typed implicit value" ;
942
938
}
943
939
// Currently, we only support zero as the implicit value.
944
940
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);
0 commit comments