@@ -1356,50 +1356,54 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1356
1356
// See buildLattices() for an explanation of rejecting certain
1357
1357
// division and shift operations.
1358
1358
if (def->getNumOperands () == 2 ) {
1359
- const auto [x, xDepSp] = buildTensorExp (op, def->getOperand (0 ));
1360
- const auto [y, yDepSp] = buildTensorExp (op, def->getOperand (1 ));
1361
- bool hasSpDep = xDepSp || yDepSp;
1359
+ const auto [x, xSpVals] = buildTensorExp (op, def->getOperand (0 ));
1360
+ const auto [y, ySpVals] = buildTensorExp (op, def->getOperand (1 ));
1361
+ // For a conjunctive operation, it yields a "sparse" result if any operand
1362
+ // is sparse. For a disjunctive operation, it yields a "sparse" result if
1363
+ // all operands are sparse.
1364
+ bool conjSpVals = xSpVals || ySpVals;
1365
+ bool disjSpVals = xSpVals && ySpVals;
1362
1366
if (x.has_value () && y.has_value ()) {
1363
1367
const ExprId e0 = *x;
1364
1368
const ExprId e1 = *y;
1365
1369
if (isa<arith::MulFOp>(def))
1366
- return {addExp (TensorExp::Kind::kMulF , e0 , e1 ), hasSpDep };
1370
+ return {addExp (TensorExp::Kind::kMulF , e0 , e1 ), conjSpVals };
1367
1371
if (isa<complex::MulOp>(def))
1368
- return {addExp (TensorExp::Kind::kMulC , e0 , e1 ), hasSpDep };
1372
+ return {addExp (TensorExp::Kind::kMulC , e0 , e1 ), conjSpVals };
1369
1373
if (isa<arith::MulIOp>(def))
1370
- return {addExp (TensorExp::Kind::kMulI , e0 , e1 ), hasSpDep };
1374
+ return {addExp (TensorExp::Kind::kMulI , e0 , e1 ), conjSpVals };
1371
1375
if (isa<arith::DivFOp>(def) && !maybeZero (e1 ))
1372
- return {addExp (TensorExp::Kind::kDivF , e0 , e1 ), hasSpDep };
1376
+ return {addExp (TensorExp::Kind::kDivF , e0 , e1 ), conjSpVals };
1373
1377
if (isa<complex::DivOp>(def) && !maybeZero (e1 ))
1374
- return {addExp (TensorExp::Kind::kDivC , e0 , e1 ), hasSpDep };
1378
+ return {addExp (TensorExp::Kind::kDivC , e0 , e1 ), conjSpVals };
1375
1379
if (isa<arith::DivSIOp>(def) && !maybeZero (e1 ))
1376
- return {addExp (TensorExp::Kind::kDivS , e0 , e1 ), hasSpDep };
1380
+ return {addExp (TensorExp::Kind::kDivS , e0 , e1 ), conjSpVals };
1377
1381
if (isa<arith::DivUIOp>(def) && !maybeZero (e1 ))
1378
- return {addExp (TensorExp::Kind::kDivU , e0 , e1 ), hasSpDep };
1382
+ return {addExp (TensorExp::Kind::kDivU , e0 , e1 ), conjSpVals };
1379
1383
if (isa<arith::AddFOp>(def))
1380
- return {addExp (TensorExp::Kind::kAddF , e0 , e1 ), hasSpDep };
1384
+ return {addExp (TensorExp::Kind::kAddF , e0 , e1 ), disjSpVals };
1381
1385
if (isa<complex::AddOp>(def))
1382
- return {addExp (TensorExp::Kind::kAddC , e0 , e1 ), hasSpDep };
1386
+ return {addExp (TensorExp::Kind::kAddC , e0 , e1 ), disjSpVals };
1383
1387
if (isa<arith::AddIOp>(def))
1384
- return {addExp (TensorExp::Kind::kAddI , e0 , e1 ), hasSpDep };
1388
+ return {addExp (TensorExp::Kind::kAddI , e0 , e1 ), disjSpVals };
1385
1389
if (isa<arith::SubFOp>(def))
1386
- return {addExp (TensorExp::Kind::kSubF , e0 , e1 ), hasSpDep };
1390
+ return {addExp (TensorExp::Kind::kSubF , e0 , e1 ), disjSpVals };
1387
1391
if (isa<complex::SubOp>(def))
1388
- return {addExp (TensorExp::Kind::kSubC , e0 , e1 ), hasSpDep };
1392
+ return {addExp (TensorExp::Kind::kSubC , e0 , e1 ), disjSpVals };
1389
1393
if (isa<arith::SubIOp>(def))
1390
- return {addExp (TensorExp::Kind::kSubI , e0 , e1 ), hasSpDep };
1394
+ return {addExp (TensorExp::Kind::kSubI , e0 , e1 ), disjSpVals };
1391
1395
if (isa<arith::AndIOp>(def))
1392
- return {addExp (TensorExp::Kind::kAndI , e0 , e1 ), hasSpDep };
1396
+ return {addExp (TensorExp::Kind::kAndI , e0 , e1 ), conjSpVals };
1393
1397
if (isa<arith::OrIOp>(def))
1394
- return {addExp (TensorExp::Kind::kOrI , e0 , e1 ), hasSpDep };
1398
+ return {addExp (TensorExp::Kind::kOrI , e0 , e1 ), disjSpVals };
1395
1399
if (isa<arith::XOrIOp>(def))
1396
- return {addExp (TensorExp::Kind::kXorI , e0 , e1 ), hasSpDep };
1400
+ return {addExp (TensorExp::Kind::kXorI , e0 , e1 ), disjSpVals };
1397
1401
if (isa<arith::ShRSIOp>(def) && isInvariant (e1 ))
1398
- return {addExp (TensorExp::Kind::kShrS , e0 , e1 ), hasSpDep };
1402
+ return {addExp (TensorExp::Kind::kShrS , e0 , e1 ), conjSpVals };
1399
1403
if (isa<arith::ShRUIOp>(def) && isInvariant (e1 ))
1400
- return {addExp (TensorExp::Kind::kShrU , e0 , e1 ), hasSpDep };
1404
+ return {addExp (TensorExp::Kind::kShrU , e0 , e1 ), conjSpVals };
1401
1405
if (isa<arith::ShLIOp>(def) && isInvariant (e1 ))
1402
- return {addExp (TensorExp::Kind::kShlI , e0 , e1 ), hasSpDep };
1406
+ return {addExp (TensorExp::Kind::kShlI , e0 , e1 ), conjSpVals };
1403
1407
if (auto ci = dyn_cast<arith::CmpIOp>(def)) {
1404
1408
if (ci.getPredicate () == arith::CmpIPredicate::eq &&
1405
1409
ci.getPredicate () == arith::CmpIPredicate::sle &&
@@ -1413,7 +1417,7 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1413
1417
1414
1418
auto e = addExp (TensorExp::Kind::kCmpI , e0 , e1 , nullptr ,
1415
1419
ci.getPredicateAttr ());
1416
- return {e, hasSpDep };
1420
+ return {e, conjSpVals };
1417
1421
}
1418
1422
if (auto cf = dyn_cast<arith::CmpFOp>(def)) {
1419
1423
if (cf.getPredicate () == arith::CmpFPredicate::OEQ &&
@@ -1431,15 +1435,15 @@ Merger::buildTensorExp(linalg::GenericOp op, Value v) {
1431
1435
}
1432
1436
auto e = addExp (TensorExp::Kind::kCmpF , e0 , e1 , nullptr ,
1433
1437
cf.getPredicateAttr ());
1434
- return {e, hasSpDep };
1438
+ return {e, conjSpVals };
1435
1439
}
1436
1440
if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
1437
1441
if (isAdmissibleBranch (binop, binop.getOverlapRegion ()) &&
1438
1442
(binop.getLeftIdentity () ||
1439
1443
isAdmissibleBranch (binop, binop.getLeftRegion ())) &&
1440
1444
(binop.getRightIdentity () ||
1441
1445
isAdmissibleBranch (binop, binop.getRightRegion ())))
1442
- return {addExp (TensorExp::Kind::kBinary , e0 , e1 , def), hasSpDep };
1446
+ return {addExp (TensorExp::Kind::kBinary , e0 , e1 , def), conjSpVals };
1443
1447
}
1444
1448
}
1445
1449
}
0 commit comments