@@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
     // The coordinates should be in shape of <? x rank>
     unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
     if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
-      op->emitError("input/output trailing COO level-ranks don't match");
+      return op->emitError("input/output trailing COO level-ranks don't match");
    }
  }
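The fix in every hunk of this commit follows the same rule: `Operation::emitError` only records a diagnostic and hands back an `InFlightDiagnostic`; unless that result is returned, the verifier falls through to its final `return success()` and invalid IR slips past verification. A minimal sketch of the intended shape, using a hypothetical `FooOp` purely for illustration:

LogicalResult FooOp::verify() {
  if (getRank() < 1)
    // Returning the InFlightDiagnostic converts it to failure(); calling
    // emitError() without the return would print the message but still
    // let verification succeed.
    return emitError("expected a positive rank");
  return success();
}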
@@ -1350,7 +1350,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
 }
 
 LogicalResult AssembleOp::verify() {
-  const auto valuesTp = getRankedTensorType(getValues());
+  RankedTensorType valuesTp = getValues().getType();
   const auto lvlsTp = getLevels().getTypes();
   const auto resTp = getSparseTensorType(getResult());
   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
@@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() {
   if (ot.getType() != rt.getType())
     return emitError("output levels and return levels type mismatch");
 
-  const auto valuesTp = getRankedTensorType(getRetValues());
+  RankedTensorType valuesTp = getRetValues().getType();
   const auto lvlsTp = getRetLevels().getTypes();
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
 }
 
 LogicalResult ConvertOp::verify() {
-  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
-    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
-      if (tp1.getRank() != tp2.getRank())
-        return emitError("unexpected conversion mismatch in rank");
-      auto dstEnc =
-          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
-      if (dstEnc && dstEnc.isSlice())
-        return emitError("cannot convert to a sparse tensor slice");
-
-      auto shape1 = tp1.getShape();
-      auto shape2 = tp2.getShape();
-      // Accept size matches between the source and the destination type
-      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
-      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
-      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
-        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
-          return emitError("unexpected conversion mismatch in dimension ") << d;
-      return success();
-    }
-  }
-  return emitError("unexpected type in convert");
+  RankedTensorType tp1 = getSource().getType();
+  RankedTensorType tp2 = getDest().getType();
+  if (tp1.getRank() != tp2.getRank())
+    return emitError("unexpected conversion mismatch in rank");
+  auto dstEnc =
+      llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
+  if (dstEnc && dstEnc.isSlice())
+    return emitError("cannot convert to a sparse tensor slice");
+
+  auto shape1 = tp1.getShape();
+  auto shape2 = tp2.getShape();
+  // Accept size matches between the source and the destination type
+  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
+  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
+    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
+      return emitError("unexpected conversion mismatch in dimension ") << d;
+  return success();
 }
 
 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
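A note on the ConvertOp rewrite: dropping the `dyn_cast` checks presumably relies on the op's ODS definition already constraining both source and destination to ranked tensor types, so `getType()` yields a `RankedTensorType` directly and the old "unexpected type in convert" path is unreachable. The per-dimension rule the loop implements can also be stated in isolation; a small sketch (illustrative helper, not part of the patch):

// Compatible iff the sizes match exactly or the destination is dynamic:
// 10 vs. 10, 10 vs. ?, and ? vs. ? pass; 10 vs. 20 and ? vs. 10 fail,
// since the latter would need a runtime assert on the source's real size.
static bool compatibleDim(int64_t srcSize, int64_t dstSize) {
  return srcSize == dstSize || dstSize == ShapedType::kDynamic;
}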
@@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() {
   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
     auto stt = getSparseTensorType(getSource());
     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
-      emitError("Level index exceeds the rank of the input sparse tensor");
+      return emitError(
+          "Level index exceeds the rank of the input sparse tensor");
   }
   return success();
 }
@@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
 }
 
 LogicalResult ToSliceOffsetOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
 }
 
 LogicalResult ToSliceStrideOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
@@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() {
   const auto iTp = IndexType::get(getContext());
   for (Dimension d = 0; d < dimRank; d++)
     if (args[d].getType() != iTp)
-      emitError(
+      return emitError(
           llvm::formatv("Expecting Index type for argument at index {0}", d));
 
   const auto elemTp = t.getElementType();
   const auto valueTp = args[dimRank].getType();
   if (elemTp != valueTp)
-    emitError(llvm::formatv("Unmatched element type between input tensor and "
-                            "block argument, expected:{0}, got: {1}",
-                            elemTp, valueTp));
+    return emitError(
+        llvm::formatv("Unmatched element type between input tensor and "
+                      "block argument, expected:{0}, got: {1}",
+                      elemTp, valueTp));
   return success();
 }
@@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() {
   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
 
   if (!srcStt.isCOOType() || !dstStt.isCOOType())
-    emitError("Expected COO sparse tensors only");
+    return emitError("Expected COO sparse tensors only");
 
   if (!srcStt.hasSameDimToLvl(dstStt))
-    emitError("Unmatched dim2lvl map between input and result COO");
+    return emitError("Unmatched dim2lvl map between input and result COO");
 
   if (srcStt.getPosType() != dstStt.getPosType() ||
       srcStt.getCrdType() != dstStt.getCrdType() ||
       srcStt.getElementType() != dstStt.getElementType())
-    emitError("Unmatched storage format between input and result COO");
+    return emitError("Unmatched storage format between input and result COO");
 
   return success();
 }
@@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() {
   AffineMap xPerm = getPermMap();
   uint64_t nx = xPerm.getNumDims();
   if (nx < 1)
-    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+    return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
 
   if (!xPerm.isPermutation())
-    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+    return emitError(
+        llvm::formatv("Expected a permutation map, got {0}", xPerm));
 
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -2056,19 +2056,24 @@ LogicalResult SortOp::verify() {
     return success();
 
   // Verify dimensions.
-  const auto checkDim = [&](Value v, Size minSize, const char *message) {
+  const auto checkDim = [&](Value v, Size minSize,
+                            const char *message) -> LogicalResult {
     const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
-      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+      return emitError(
+          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+    return success();
   };
   uint64_t n = cn.value();
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr())
     ny = nyAttr.getInt();
-  checkDim(getXy(), n * (nx + ny),
-           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
+  if (failed(checkDim(getXy(), n * (nx + ny),
+                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
+    return failure();
   for (Value opnd : getYs())
-    checkDim(opnd, n, "Expected dimension(y) >= n");
+    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
+      return failure();
 
   return success();
 }
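The `checkDim` change shows the companion idiom for helpers: have the helper return `LogicalResult` and make every caller propagate it with `failed(...)`; otherwise a failing check inside the lambda is silently discarded, just like a bare `emitError`. A condensed sketch of the same idiom, with hypothetical names:

auto check = [&](bool ok, const llvm::Twine &msg) -> LogicalResult {
  if (!ok)
    return emitError(msg); // InFlightDiagnostic converts to failure().
  return success();
};
// Each call site must forward the failure, or the check is a no-op:
if (failed(check(nx >= 1, "Expected rank(perm_map) > 1")))
  return failure();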
@@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
   }
 
   if (lvlHi <= lvlLo)
-    parser.emitError(parser.getNameLoc(),
-                     "expect larger level upper bound than lower bound");
+    return parser.emitError(parser.getNameLoc(),
+                            "expect larger level upper bound than lower bound");
 
   return success();
 }
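The same discipline applies in custom parsers: `AsmParser::emitError` likewise returns an `InFlightDiagnostic` that converts to a failed `ParseResult`, so it should be returned directly, as the last hunk does. A self-contained sketch under that assumption (hypothetical helper, not from the patch):

static ParseResult parsePositive(AsmParser &parser, int64_t &value) {
  if (parser.parseInteger(value))
    return failure();
  if (value <= 0)
    // Returning the diagnostic makes the enclosing parse fail immediately.
    return parser.emitError(parser.getNameLoc(), "expected a positive value");
  return success();
}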