Skip to content

Commit 2e2011d

Browse files
authored
[mlir][sparse] avoid excessive macro magic (#70276)
The shorthands are not even always shorter and the code is less clear than when simply written out.
1 parent 76dea22 commit 2e2011d

File tree

1 file changed

+64
-73
lines changed

1 file changed

+64
-73
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 64 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,6 @@
3434
using namespace mlir;
3535
using namespace mlir::sparse_tensor;
3636

37-
#define RETURN_FAILURE_IF_FAILED(X) \
38-
if (failed(X)) { \
39-
return failure(); \
40-
}
41-
4237
//===----------------------------------------------------------------------===//
4338
// Local convenience methods.
4439
//===----------------------------------------------------------------------===//
@@ -68,10 +63,6 @@ void StorageLayout::foreachField(
6863
llvm::function_ref<bool(FieldIndex, SparseTensorFieldKind, Level,
6964
DimLevelType)>
7065
callback) const {
71-
#define RETURN_ON_FALSE(fidx, kind, lvl, dlt) \
72-
if (!(callback(fidx, kind, lvl, dlt))) \
73-
return;
74-
7566
const auto lvlTypes = enc.getLvlTypes();
7667
const Level lvlRank = enc.getLvlRank();
7768
const Level cooStart = getCOOStart(enc);
@@ -81,21 +72,22 @@ void StorageLayout::foreachField(
8172
for (Level l = 0; l < end; l++) {
8273
const auto dlt = lvlTypes[l];
8374
if (isDLTWithPos(dlt)) {
84-
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt);
75+
if (!(callback(fieldIdx++, SparseTensorFieldKind::PosMemRef, l, dlt)))
76+
return;
8577
}
8678
if (isDLTWithCrd(dlt)) {
87-
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt);
79+
if (!(callback(fieldIdx++, SparseTensorFieldKind::CrdMemRef, l, dlt)))
80+
return;
8881
}
8982
}
9083
// The values array.
91-
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
92-
DimLevelType::Undef);
93-
84+
if (!(callback(fieldIdx++, SparseTensorFieldKind::ValMemRef, kInvalidLevel,
85+
DimLevelType::Undef)))
86+
return;
9487
// Put metadata at the end.
95-
RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
96-
DimLevelType::Undef);
97-
98-
#undef RETURN_ON_FALSE
88+
if (!(callback(fieldIdx++, SparseTensorFieldKind::StorageSpec, kInvalidLevel,
89+
DimLevelType::Undef)))
90+
return;
9991
}
10092

10193
void sparse_tensor::foreachFieldAndTypeInSparseTensor(
@@ -435,18 +427,11 @@ SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
435427
}
436428

437429
Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
438-
#define RETURN_ON_FAIL(stmt) \
439-
if (failed(stmt)) { \
440-
return {}; \
441-
}
442-
#define ERROR_IF(COND, MSG) \
443-
if (COND) { \
444-
parser.emitError(parser.getNameLoc(), MSG); \
445-
return {}; \
446-
}
447-
448-
RETURN_ON_FAIL(parser.parseLess())
449-
RETURN_ON_FAIL(parser.parseLBrace())
430+
// Open "<{" part.
431+
if (failed(parser.parseLess()))
432+
return {};
433+
if (failed(parser.parseLBrace()))
434+
return {};
450435

451436
// Process the data from the parsed dictionary value into struct-like data.
452437
SmallVector<DimLevelType> lvlTypes;
@@ -466,13 +451,15 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
466451
}
467452
unsigned keyWordIndex = it - keys.begin();
468453
// Consume the `=` after keys
469-
RETURN_ON_FAIL(parser.parseEqual())
454+
if (failed(parser.parseEqual()))
455+
return {};
470456
// Dispatch on keyword.
471457
switch (keyWordIndex) {
472458
case 0: { // map
473459
ir_detail::DimLvlMapParser cParser(parser);
474460
auto res = cParser.parseDimLvlMap();
475-
RETURN_ON_FAIL(res);
461+
if (failed(res))
462+
return {};
476463
const auto &dlm = *res;
477464

478465
const Level lvlRank = dlm.getLvlRank();
@@ -504,17 +491,27 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
504491
}
505492
case 1: { // posWidth
506493
Attribute attr;
507-
RETURN_ON_FAIL(parser.parseAttribute(attr))
494+
if (failed(parser.parseAttribute(attr)))
495+
return {};
508496
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
509-
ERROR_IF(!intAttr, "expected an integral position bitwidth")
497+
if (!intAttr) {
498+
parser.emitError(parser.getNameLoc(),
499+
"expected an integral position bitwidth");
500+
return {};
501+
}
510502
posWidth = intAttr.getInt();
511503
break;
512504
}
513505
case 2: { // crdWidth
514506
Attribute attr;
515-
RETURN_ON_FAIL(parser.parseAttribute(attr))
507+
if (failed(parser.parseAttribute(attr)))
508+
return {};
516509
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr);
517-
ERROR_IF(!intAttr, "expected an integral index bitwidth")
510+
if (!intAttr) {
511+
parser.emitError(parser.getNameLoc(),
512+
"expected an integral index bitwidth");
513+
return {};
514+
}
518515
crdWidth = intAttr.getInt();
519516
break;
520517
}
@@ -524,10 +521,11 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
524521
break;
525522
}
526523

527-
RETURN_ON_FAIL(parser.parseRBrace())
528-
RETURN_ON_FAIL(parser.parseGreater())
529-
#undef ERROR_IF
530-
#undef RETURN_ON_FAIL
524+
// Close "}>" part.
525+
if (failed(parser.parseRBrace()))
526+
return {};
527+
if (failed(parser.parseGreater()))
528+
return {};
531529

532530
// Construct struct-like storage for attribute.
533531
if (!lvlToDim || lvlToDim.isEmpty()) {
@@ -668,9 +666,9 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
668666
function_ref<InFlightDiagnostic()> emitError) const {
669667
// Check structural integrity. In particular, this ensures that the
670668
// level-rank is coherent across all the fields.
671-
RETURN_FAILURE_IF_FAILED(verify(emitError, getLvlTypes(), getDimToLvl(),
672-
getLvlToDim(), getPosWidth(), getCrdWidth(),
673-
getDimSlices()))
669+
if (failed(verify(emitError, getLvlTypes(), getDimToLvl(), getLvlToDim(),
670+
getPosWidth(), getCrdWidth(), getDimSlices())))
671+
return failure();
674672
// Check integrity with tensor type specifics. In particular, we
675673
// need only check that the dimension-rank of the tensor agrees with
676674
// the dimension-rank of the encoding.
@@ -926,10 +924,6 @@ Level mlir::sparse_tensor::toStoredDim(RankedTensorType type, Dimension d) {
926924
return toStoredDim(getSparseTensorEncoding(type), d);
927925
}
928926

929-
//===----------------------------------------------------------------------===//
930-
// SparseTensorDialect Types.
931-
//===----------------------------------------------------------------------===//
932-
933927
/// We normalized sparse tensor encoding attribute by always using
934928
/// ordered/unique DLT such that "compressed_nu_no" and "compressed_nu" (as well
935929
/// as other variants) lead to the same storage specifier type, and stripping
@@ -1340,9 +1334,8 @@ LogicalResult ToSliceStrideOp::verify() {
13401334
}
13411335

13421336
LogicalResult GetStorageSpecifierOp::verify() {
1343-
RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
1344-
getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
1345-
return success();
1337+
return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1338+
getSpecifier(), getOperation());
13461339
}
13471340

13481341
template <typename SpecifierOp>
@@ -1360,9 +1353,8 @@ OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) {
13601353
}
13611354

13621355
LogicalResult SetStorageSpecifierOp::verify() {
1363-
RETURN_FAILURE_IF_FAILED(verifySparsifierGetterSetter(
1364-
getSpecifierKind(), getLevel(), getSpecifier(), getOperation()))
1365-
return success();
1356+
return verifySparsifierGetterSetter(getSpecifierKind(), getLevel(),
1357+
getSpecifier(), getOperation());
13661358
}
13671359

13681360
template <class T>
@@ -1404,20 +1396,23 @@ LogicalResult BinaryOp::verify() {
14041396
// Check correct number of block arguments and return type for each
14051397
// non-empty region.
14061398
if (!overlap.empty()) {
1407-
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
1408-
this, overlap, "overlap", TypeRange{leftType, rightType}, outputType))
1399+
if (failed(verifyNumBlockArgs(this, overlap, "overlap",
1400+
TypeRange{leftType, rightType}, outputType)))
1401+
return failure();
14091402
}
14101403
if (!left.empty()) {
1411-
RETURN_FAILURE_IF_FAILED(
1412-
verifyNumBlockArgs(this, left, "left", TypeRange{leftType}, outputType))
1404+
if (failed(verifyNumBlockArgs(this, left, "left", TypeRange{leftType},
1405+
outputType)))
1406+
return failure();
14131407
} else if (getLeftIdentity()) {
14141408
if (leftType != outputType)
14151409
return emitError("left=identity requires first argument to have the same "
14161410
"type as the output");
14171411
}
14181412
if (!right.empty()) {
1419-
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
1420-
this, right, "right", TypeRange{rightType}, outputType))
1413+
if (failed(verifyNumBlockArgs(this, right, "right", TypeRange{rightType},
1414+
outputType)))
1415+
return failure();
14211416
} else if (getRightIdentity()) {
14221417
if (rightType != outputType)
14231418
return emitError("right=identity requires second argument to have the "
@@ -1434,13 +1429,15 @@ LogicalResult UnaryOp::verify() {
14341429
// non-empty region.
14351430
Region &present = getPresentRegion();
14361431
if (!present.empty()) {
1437-
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
1438-
this, present, "present", TypeRange{inputType}, outputType))
1432+
if (failed(verifyNumBlockArgs(this, present, "present",
1433+
TypeRange{inputType}, outputType)))
1434+
return failure();
14391435
}
14401436
Region &absent = getAbsentRegion();
14411437
if (!absent.empty()) {
1442-
RETURN_FAILURE_IF_FAILED(
1443-
verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
1438+
if (failed(verifyNumBlockArgs(this, absent, "absent", TypeRange{},
1439+
outputType)))
1440+
return failure();
14441441
// Absent branch can only yield invariant values.
14451442
Block *absentBlock = &absent.front();
14461443
Block *parent = getOperation()->getBlock();
@@ -1655,22 +1652,18 @@ LogicalResult ReorderCOOOp::verify() {
16551652

16561653
LogicalResult ReduceOp::verify() {
16571654
Type inputType = getX().getType();
1658-
// Check correct number of block arguments and return type.
16591655
Region &formula = getRegion();
1660-
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(
1661-
this, formula, "reduce", TypeRange{inputType, inputType}, inputType))
1662-
return success();
1656+
return verifyNumBlockArgs(this, formula, "reduce",
1657+
TypeRange{inputType, inputType}, inputType);
16631658
}
16641659

16651660
LogicalResult SelectOp::verify() {
16661661
Builder b(getContext());
16671662
Type inputType = getX().getType();
16681663
Type boolType = b.getI1Type();
1669-
// Check correct number of block arguments and return type.
16701664
Region &formula = getRegion();
1671-
RETURN_FAILURE_IF_FAILED(verifyNumBlockArgs(this, formula, "select",
1672-
TypeRange{inputType}, boolType))
1673-
return success();
1665+
return verifyNumBlockArgs(this, formula, "select", TypeRange{inputType},
1666+
boolType);
16741667
}
16751668

16761669
LogicalResult SortOp::verify() {
@@ -1725,8 +1718,6 @@ LogicalResult YieldOp::verify() {
17251718
"reduce, select or foreach");
17261719
}
17271720

1728-
#undef RETURN_FAILURE_IF_FAILED
1729-
17301721
/// Materialize a single constant operation from a given attribute value with
17311722
/// the desired resultant type.
17321723
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,

0 commit comments

Comments
 (0)