34
34
using namespace mlir ;
35
35
using namespace mlir ::sparse_tensor;
36
36
37
- #define RETURN_FAILURE_IF_FAILED (X ) \
38
- if (failed(X)) { \
39
- return failure (); \
40
- }
41
-
42
37
// ===----------------------------------------------------------------------===//
43
38
// Local convenience methods.
44
39
// ===----------------------------------------------------------------------===//
@@ -68,10 +63,6 @@ void StorageLayout::foreachField(
68
63
llvm::function_ref<bool (FieldIndex, SparseTensorFieldKind, Level,
69
64
DimLevelType)>
70
65
callback) const {
71
- #define RETURN_ON_FALSE (fidx, kind, lvl, dlt ) \
72
- if (!(callback (fidx, kind, lvl, dlt))) \
73
- return ;
74
-
75
66
const auto lvlTypes = enc.getLvlTypes ();
76
67
const Level lvlRank = enc.getLvlRank ();
77
68
const Level cooStart = getCOOStart (enc);
@@ -81,21 +72,22 @@ void StorageLayout::foreachField(
81
72
for (Level l = 0 ; l < end; l++) {
82
73
const auto dlt = lvlTypes[l];
83
74
if (isDLTWithPos (dlt)) {
84
- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
75
+ if (!(callback (fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
76
+ return ;
85
77
}
86
78
if (isDLTWithCrd (dlt)) {
87
- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
79
+ if (!(callback (fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
80
+ return ;
88
81
}
89
82
}
90
83
// The values array.
91
- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel ,
92
- DimLevelType::Undef);
93
-
84
+ if (!( callback (fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel ,
85
+ DimLevelType::Undef)))
86
+ return ;
94
87
// Put metadata at the end.
95
- RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel ,
96
- DimLevelType::Undef);
97
-
98
- #undef RETURN_ON_FALSE
88
+ if (!(callback (fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel ,
89
+ DimLevelType::Undef)))
90
+ return ;
99
91
}
100
92
101
93
void sparse_tensor::foreachFieldAndTypeInSparseTensor (
@@ -435,18 +427,11 @@ SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
435
427
}
436
428
437
429
Attribute SparseTensorEncodingAttr::parse (AsmParser &parser, Type type) {
438
- #define RETURN_ON_FAIL (stmt ) \
439
- if (failed (stmt)) { \
440
- return {}; \
441
- }
442
- #define ERROR_IF (COND, MSG ) \
443
- if (COND) { \
444
- parser.emitError (parser.getNameLoc (), MSG); \
445
- return {}; \
446
- }
447
-
448
- RETURN_ON_FAIL (parser.parseLess ())
449
- RETURN_ON_FAIL (parser.parseLBrace ())
430
+ // Open "<{" part.
431
+ if (failed (parser.parseLess ()))
432
+ return {};
433
+ if (failed (parser.parseLBrace ()))
434
+ return {};
450
435
451
436
// Process the data from the parsed dictionary value into struct-like data.
452
437
SmallVector<DimLevelType> lvlTypes;
@@ -466,13 +451,15 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
466
451
}
467
452
unsigned keyWordIndex = it - keys.begin ();
468
453
// Consume the `=` after keys
469
- RETURN_ON_FAIL (parser.parseEqual ())
454
+ if (failed (parser.parseEqual ()))
455
+ return {};
470
456
// Dispatch on keyword.
471
457
switch (keyWordIndex) {
472
458
case 0 : { // map
473
459
ir_detail::DimLvlMapParser cParser (parser);
474
460
auto res = cParser.parseDimLvlMap ();
475
- RETURN_ON_FAIL (res);
461
+ if (failed (res))
462
+ return {};
476
463
const auto &dlm = *res;
477
464
478
465
const Level lvlRank = dlm.getLvlRank ();
@@ -504,17 +491,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
504
491
}
505
492
case 1 : { // posWidth
506
493
Attribute attr;
507
- RETURN_ON_FAIL (parser.parseAttribute (attr))
494
+ if (failed (parser.parseAttribute (attr)))
495
+ return {};
508
496
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
509
- ERROR_IF (!intAttr, " expected an integral position bitwidth" )
497
+ if (!intAttr) {
498
+ parser.emitError (parser.getNameLoc (),
499
+ " expected an integral position bitwidth" );
500
+ return {};
501
+ }
510
502
posWidth = intAttr.getInt ();
511
503
break ;
512
504
}
513
505
case 2 : { // crdWidth
514
506
Attribute attr;
515
- RETURN_ON_FAIL (parser.parseAttribute (attr))
507
+ if (failed (parser.parseAttribute (attr)))
508
+ return {};
516
509
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
517
- ERROR_IF (!intAttr, " expected an integral index bitwidth" )
510
+ if (!intAttr) {
511
+ parser.emitError (parser.getNameLoc (),
512
+ " expected an integral index bitwidth" );
513
+ return {};
514
+ }
518
515
crdWidth = intAttr.getInt ();
519
516
break ;
520
517
}
@@ -524,10 +521,11 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
524
521
break ;
525
522
}
526
523
527
- RETURN_ON_FAIL (parser.parseRBrace ())
528
- RETURN_ON_FAIL (parser.parseGreater ())
529
- #undef ERROR_IF
530
- #undef RETURN_ON_FAIL
524
+ // Close "}>" part.
525
+ if (failed (parser.parseRBrace ()))
526
+ return {};
527
+ if (failed (parser.parseGreater ()))
528
+ return {};
531
529
532
530
// Construct struct-like storage for attribute.
533
531
if (!lvlToDim || lvlToDim.isEmpty ()) {
@@ -668,9 +666,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
668
666
function_ref<InFlightDiagnostic()> emitError) const {
669
667
// Check structural integrity. In particular, this ensures that the
670
668
// level-rank is coherent across all the fields.
671
- RETURN_FAILURE_IF_FAILED ( verify (emitError, getLvlTypes (), getDimToLvl (),
672
- getLvlToDim (), getPosWidth (), getCrdWidth (),
673
- getDimSlices ()))
669
+ if ( failed ( verify (emitError, getLvlTypes (), getDimToLvl (), getLvlToDim (),
670
+ getPosWidth (), getCrdWidth (), getDimSlices ())))
671
+ return failure ();
674
672
// Check integrity with tensor type specifics. In particular, we
675
673
// need only check that the dimension-rank of the tensor agrees with
676
674
// the dimension-rank of the encoding.
@@ -926,10 +924,6 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
926
924
return toStoredDim (getSparseTensorEncoding (type), d);
927
925
}
928
926
929
- // ===----------------------------------------------------------------------===//
930
- // SparseTensorDialect Types.
931
- // ===----------------------------------------------------------------------===//
932
-
933
927
// / We normalized sparse tensor encoding attribute by always using
934
928
// / ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
935
929
// / as other variants) lead to the same storage specifier type, and stripping
@@ -1340,9 +1334,8 @@ LogicalResult ToSliceStrideOp::verify() {
1340
1334
}
1341
1335
1342
1336
LogicalResult GetStorageSpecifierOp::verify () {
1343
- RETURN_FAILURE_IF_FAILED (verifySparsifierGetterSetter (
1344
- getSpecifierKind (), getLevel (), getSpecifier (), getOperation ()))
1345
- return success ();
1337
+ return verifySparsifierGetterSetter (getSpecifierKind (), getLevel (),
1338
+ getSpecifier (), getOperation ());
1346
1339
}
1347
1340
1348
1341
template <typename SpecifierOp>
@@ -1360,9 +1353,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
1360
1353
}
1361
1354
1362
1355
LogicalResult SetStorageSpecifierOp::verify () {
1363
- RETURN_FAILURE_IF_FAILED (verifySparsifierGetterSetter (
1364
- getSpecifierKind (), getLevel (), getSpecifier (), getOperation ()))
1365
- return success ();
1356
+ return verifySparsifierGetterSetter (getSpecifierKind (), getLevel (),
1357
+ getSpecifier (), getOperation ());
1366
1358
}
1367
1359
1368
1360
template <class T >
@@ -1404,20 +1396,23 @@ LogicalResult BinaryOp::verify() {
1404
1396
// Check correct number of block arguments and return type for each
1405
1397
// non-empty region.
1406
1398
if (!overlap.empty ()) {
1407
- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1408
- this , overlap, " overlap" , TypeRange{leftType, rightType}, outputType))
1399
+ if (failed (verifyNumBlockArgs (this , overlap, " overlap" ,
1400
+ TypeRange{leftType, rightType}, outputType)))
1401
+ return failure ();
1409
1402
}
1410
1403
if (!left.empty ()) {
1411
- RETURN_FAILURE_IF_FAILED (
1412
- verifyNumBlockArgs (this , left, " left" , TypeRange{leftType}, outputType))
1404
+ if (failed (verifyNumBlockArgs (this , left, " left" , TypeRange{leftType},
1405
+ outputType)))
1406
+ return failure ();
1413
1407
} else if (getLeftIdentity ()) {
1414
1408
if (leftType != outputType)
1415
1409
return emitError (" left=identity requires first argument to have the same "
1416
1410
" type as the output" );
1417
1411
}
1418
1412
if (!right.empty ()) {
1419
- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1420
- this , right, " right" , TypeRange{rightType}, outputType))
1413
+ if (failed (verifyNumBlockArgs (this , right, " right" , TypeRange{rightType},
1414
+ outputType)))
1415
+ return failure ();
1421
1416
} else if (getRightIdentity ()) {
1422
1417
if (rightType != outputType)
1423
1418
return emitError (" right=identity requires second argument to have the "
@@ -1434,13 +1429,15 @@ LogicalResult UnaryOp::verify() {
1434
1429
// non-empty region.
1435
1430
Region &present = getPresentRegion ();
1436
1431
if (!present.empty ()) {
1437
- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1438
- this , present, " present" , TypeRange{inputType}, outputType))
1432
+ if (failed (verifyNumBlockArgs (this , present, " present" ,
1433
+ TypeRange{inputType}, outputType)))
1434
+ return failure ();
1439
1435
}
1440
1436
Region &absent = getAbsentRegion ();
1441
1437
if (!absent.empty ()) {
1442
- RETURN_FAILURE_IF_FAILED (
1443
- verifyNumBlockArgs (this , absent, " absent" , TypeRange{}, outputType))
1438
+ if (failed (verifyNumBlockArgs (this , absent, " absent" , TypeRange{},
1439
+ outputType)))
1440
+ return failure ();
1444
1441
// Absent branch can only yield invariant values.
1445
1442
Block *absentBlock = &absent.front ();
1446
1443
Block *parent = getOperation ()->getBlock ();
@@ -1655,22 +1652,18 @@ LogicalResult ReorderCOOOp::verify() {
1655
1652
1656
1653
LogicalResult ReduceOp::verify () {
1657
1654
Type inputType = getX ().getType ();
1658
- // Check correct number of block arguments and return type.
1659
1655
Region &formula = getRegion ();
1660
- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (
1661
- this , formula, " reduce" , TypeRange{inputType, inputType}, inputType))
1662
- return success ();
1656
+ return verifyNumBlockArgs (this , formula, " reduce" ,
1657
+ TypeRange{inputType, inputType}, inputType);
1663
1658
}
1664
1659
1665
1660
LogicalResult SelectOp::verify () {
1666
1661
Builder b (getContext ());
1667
1662
Type inputType = getX ().getType ();
1668
1663
Type boolType = b.getI1Type ();
1669
- // Check correct number of block arguments and return type.
1670
1664
Region &formula = getRegion ();
1671
- RETURN_FAILURE_IF_FAILED (verifyNumBlockArgs (this , formula, " select" ,
1672
- TypeRange{inputType}, boolType))
1673
- return success ();
1665
+ return verifyNumBlockArgs (this , formula, " select" , TypeRange{inputType},
1666
+ boolType);
1674
1667
}
1675
1668
1676
1669
LogicalResult SortOp::verify () {
@@ -1725,8 +1718,6 @@ LogicalResult YieldOp::verify() {
1725
1718
" reduce, select or foreach" );
1726
1719
}
1727
1720
1728
- #undef RETURN_FAILURE_IF_FAILED
1729
-
1730
1721
// / Materialize a single constant operation from a given attribute value with
1731
1722
// / the desired resultant type.
1732
1723
Operation *SparseTensorDialect::materializeConstant (OpBuilder &builder,
0 commit comments