Skip to content

Commit 9442427

Browse files
use TypedAttr
1 parent 7497b48 commit 9442427

File tree

2 files changed

+18
-27
lines changed

2 files changed

+18
-27
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,11 +512,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
512512
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
513513
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
514514
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
515-
516-
//
517-
// Explicit/implicit value methods.
518-
//
519-
Type getMismatchedValueType(Type elementType, Attribute val) const;
520515
}];
521516

522517
let genVerifyDecl = 1;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -893,19 +893,6 @@ LogicalResult SparseTensorEncodingAttr::verify(
893893
return success();
894894
}
895895

896-
Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
897-
Attribute val) const {
898-
Type type;
899-
auto fVal = llvm::dyn_cast<FloatAttr>(val);
900-
auto intVal = llvm::dyn_cast<IntegerAttr>(val);
901-
if (fVal && fVal.getType() != elementType) {
902-
type = fVal.getType();
903-
} else if (intVal && intVal.getType() != elementType) {
904-
type = intVal.getType();
905-
}
906-
return type;
907-
}
908-
909896
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
910897
ArrayRef<Size> dimShape, Type elementType,
911898
function_ref<InFlightDiagnostic()> emitError) const {
@@ -925,20 +912,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
925912
return emitError()
926913
<< "dimension-rank mismatch between encoding and tensor shape: "
927914
<< getDimRank() << " != " << dimRank;
928-
Type type;
929915
if (getExplicitVal()) {
930-
if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
931-
return emitError() << "explicit value type mismatch between encoding and "
932-
<< "tensor element type: " << type
933-
<< " != " << elementType;
916+
if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal())) {
917+
Type attrType = typedAttr.getType();
918+
if (attrType != elementType) {
919+
return emitError()
920+
<< "explicit value type mismatch between encoding and "
921+
<< "tensor element type: " << attrType << " != " << elementType;
922+
}
923+
} else {
924+
return emitError() << "expected typed explicit value";
934925
}
935926
}
936927
if (getImplicitVal()) {
937928
auto impVal = getImplicitVal();
938-
if ((type = getMismatchedValueType(elementType, impVal))) {
939-
return emitError() << "implicit value type mismatch between encoding and "
940-
<< "tensor element type: " << type
941-
<< " != " << elementType;
929+
if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal())) {
930+
Type attrType = typedAttr.getType();
931+
if (attrType != elementType) {
932+
return emitError()
933+
<< "implicit value type mismatch between encoding and "
934+
<< "tensor element type: " << attrType << " != " << elementType;
935+
}
936+
} else {
937+
return emitError() << "expected typed implicit value";
942938
}
943939
// Currently, we only support zero as the implicit value.
944940
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);

0 commit comments

Comments
 (0)