Skip to content

Commit bb03496

Browse files
use TypedAttr
1 parent f4f58e4 commit bb03496

File tree

2 files changed

+18
-27
lines changed

2 files changed

+18
-27
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,11 +512,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
512512
void printSymbols(AffineMap &map, AsmPrinter &printer) const;
513513
void printDimensions(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::SparseTensorDimSliceAttr> dimSlices) const;
514514
void printLevels(AffineMap &map, AsmPrinter &printer, ArrayRef<::mlir::sparse_tensor::LevelType> lvlTypes) const;
515-
516-
//
517-
// Explicit/implicit value methods.
518-
//
519-
Type getMismatchedValueType(Type elementType, Attribute val) const;
520515
}];
521516

522517
let genVerifyDecl = 1;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -888,19 +888,6 @@ LogicalResult SparseTensorEncodingAttr::verify(
888888
return success();
889889
}
890890

891-
Type SparseTensorEncodingAttr::getMismatchedValueType(Type elementType,
892-
Attribute val) const {
893-
Type type;
894-
auto fVal = llvm::dyn_cast<FloatAttr>(val);
895-
auto intVal = llvm::dyn_cast<IntegerAttr>(val);
896-
if (fVal && fVal.getType() != elementType) {
897-
type = fVal.getType();
898-
} else if (intVal && intVal.getType() != elementType) {
899-
type = intVal.getType();
900-
}
901-
return type;
902-
}
903-
904891
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
905892
ArrayRef<Size> dimShape, Type elementType,
906893
function_ref<InFlightDiagnostic()> emitError) const {
@@ -920,20 +907,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
920907
return emitError()
921908
<< "dimension-rank mismatch between encoding and tensor shape: "
922909
<< getDimRank() << " != " << dimRank;
923-
Type type;
924910
if (getExplicitVal()) {
925-
if ((type = getMismatchedValueType(elementType, getExplicitVal()))) {
926-
return emitError() << "explicit value type mismatch between encoding and "
927-
<< "tensor element type: " << type
928-
<< " != " << elementType;
911+
if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getExplicitVal())) {
912+
Type attrType = typedAttr.getType();
913+
if (attrType != elementType) {
914+
return emitError()
915+
<< "explicit value type mismatch between encoding and "
916+
<< "tensor element type: " << attrType << " != " << elementType;
917+
}
918+
} else {
919+
return emitError() << "expected typed explicit value";
929920
}
930921
}
931922
if (getImplicitVal()) {
932923
auto impVal = getImplicitVal();
933-
if ((type = getMismatchedValueType(elementType, impVal))) {
934-
return emitError() << "implicit value type mismatch between encoding and "
935-
<< "tensor element type: " << type
936-
<< " != " << elementType;
924+
if (auto typedAttr = llvm::dyn_cast<TypedAttr>(getImplicitVal())) {
925+
Type attrType = typedAttr.getType();
926+
if (attrType != elementType) {
927+
return emitError()
928+
<< "implicit value type mismatch between encoding and "
929+
<< "tensor element type: " << attrType << " != " << elementType;
930+
}
931+
} else {
932+
return emitError() << "expected typed implicit value";
937933
}
938934
// Currently, we only support zero as the implicit value.
939935
auto impFVal = llvm::dyn_cast<FloatAttr>(impVal);

0 commit comments

Comments
 (0)