Skip to content

Commit 6bc7c9d

Browse files
authored
[mlir][sparse] infer returned type for sparse_tensor.to_[buffer] ops (#83343)
The sparse structure buffers might not always be memrefs with rank == 1 with the presence of batch levels.
1 parent 43b7dfc commit 6bc7c9d

File tree

5 files changed

+132
-163
lines changed

5 files changed

+132
-163
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,10 @@ def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemory
257257
let hasVerifier = 1;
258258
}
259259

260-
def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
260+
def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
261+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
261262
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
262-
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
263+
Results<(outs AnyNon0RankedMemRef:$result)> {
263264
let summary = "Extracts the `level`-th positions array of the `tensor`";
264265
let description = [{
265266
Returns the positions array of the tensor's storage at the given
@@ -283,9 +284,10 @@ def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
283284
let hasVerifier = 1;
284285
}
285286

286-
def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
287+
def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
288+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
287289
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
288-
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
290+
Results<(outs AnyNon0RankedMemRef:$result)> {
289291
let summary = "Extracts the `level`-th coordinates array of the `tensor`";
290292
let description = [{
291293
Returns the coordinates array of the tensor's storage at the given
@@ -309,9 +311,10 @@ def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
309311
let hasVerifier = 1;
310312
}
311313

312-
def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [Pure]>,
314+
def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
315+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
313316
Arguments<(ins AnySparseTensor:$tensor)>,
314-
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
317+
Results<(outs AnyNon0RankedMemRef:$result)> {
315318
let summary = "Extracts the linear coordinates array from a tensor";
316319
let description = [{
317320
Returns the linear coordinates array for a sparse tensor with
@@ -340,9 +343,10 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [
340343
let hasVerifier = 1;
341344
}
342345

343-
def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [Pure]>,
346+
def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
347+
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
344348
Arguments<(ins AnySparseTensor:$tensor)>,
345-
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
349+
Results<(outs AnyNon0RankedMemRef:$result)> {
346350
let summary = "Extracts numerical values array from a tensor";
347351
let description = [{
348352
Returns the values array of the sparse storage format for the given

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,6 +1445,38 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
14451445
return {};
14461446
}
14471447

1448+
template <typename ToBufferOp>
1449+
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
1450+
OpaqueProperties prop,
1451+
RegionRange region,
1452+
SmallVectorImpl<mlir::Type> &ret) {
1453+
typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
1454+
SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
1455+
Type elemTp = nullptr;
1456+
bool withStride = false;
1457+
if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
1458+
elemTp = stt.getPosType();
1459+
} else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
1460+
std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
1461+
elemTp = stt.getCrdType();
1462+
if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
1463+
withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
1464+
} else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
1465+
elemTp = stt.getElementType();
1466+
}
1467+
1468+
assert(elemTp && "unhandled operation.");
1469+
SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
1470+
bufShape.push_back(ShapedType::kDynamic);
1471+
1472+
auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
1473+
stt.getContext(), ShapedType::kDynamic,
1474+
{ShapedType::kDynamic})
1475+
: StridedLayoutAttr();
1476+
ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
1477+
return success();
1478+
}
1479+
14481480
LogicalResult ToPositionsOp::verify() {
14491481
auto stt = getSparseTensorType(getTensor());
14501482
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1454,6 +1486,14 @@ LogicalResult ToPositionsOp::verify() {
14541486
return success();
14551487
}
14561488

1489+
LogicalResult
1490+
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1491+
ValueRange ops, DictionaryAttr attr,
1492+
OpaqueProperties prop, RegionRange region,
1493+
SmallVectorImpl<mlir::Type> &ret) {
1494+
return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
1495+
}
1496+
14571497
LogicalResult ToCoordinatesOp::verify() {
14581498
auto stt = getSparseTensorType(getTensor());
14591499
if (failed(lvlIsInBounds(getLevel(), getTensor())))
@@ -1463,13 +1503,29 @@ LogicalResult ToCoordinatesOp::verify() {
14631503
return success();
14641504
}
14651505

1506+
LogicalResult
1507+
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
1508+
ValueRange ops, DictionaryAttr attr,
1509+
OpaqueProperties prop, RegionRange region,
1510+
SmallVectorImpl<mlir::Type> &ret) {
1511+
return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
1512+
}
1513+
14661514
LogicalResult ToCoordinatesBufferOp::verify() {
14671515
auto stt = getSparseTensorType(getTensor());
14681516
if (stt.getAoSCOOStart() >= stt.getLvlRank())
14691517
return emitError("expected sparse tensor with a COO region");
14701518
return success();
14711519
}
14721520

1521+
LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
1522+
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
1523+
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
1524+
SmallVectorImpl<mlir::Type> &ret) {
1525+
return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
1526+
ret);
1527+
}
1528+
14731529
LogicalResult ToValuesOp::verify() {
14741530
auto stt = getSparseTensorType(getTensor());
14751531
auto mtp = getMemRefType(getResult());
@@ -1478,6 +1534,15 @@ LogicalResult ToValuesOp::verify() {
14781534
return success();
14791535
}
14801536

1537+
LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
1538+
std::optional<Location> loc,
1539+
ValueRange ops, DictionaryAttr attr,
1540+
OpaqueProperties prop,
1541+
RegionRange region,
1542+
SmallVectorImpl<mlir::Type> &ret) {
1543+
return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
1544+
}
1545+
14811546
LogicalResult ToSliceOffsetOp::verify() {
14821547
auto rank = getRankedTensorType(getSlice()).getRank();
14831548
if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,17 +1058,9 @@ class SparseToCoordinatesConverter
10581058
// Replace the requested coordinates access with corresponding field.
10591059
// The cast_op is inserted by type converter to intermix 1:N type
10601060
// conversion.
1061-
Location loc = op.getLoc();
10621061
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1063-
Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());
1064-
1065-
// Insert a cast to bridge the actual type to the user expected type. If the
1066-
// actual type and the user expected type aren't compatible, the compiler or
1067-
// the runtime will issue an error.
1068-
Type resType = op.getResult().getType();
1069-
if (resType != field.getType())
1070-
field = rewriter.create<memref::CastOp>(loc, resType, field);
1071-
rewriter.replaceOp(op, field);
1062+
rewriter.replaceOp(
1063+
op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));
10721064

10731065
return success();
10741066
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -618,10 +618,10 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
618618
rewriter.create<vector::PrintOp>(loc, nse);
619619
// Use the "codegen" foreach loop construct to iterate over
620620
// all typical sparse tensor components for printing.
621-
foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
622-
&tensor](Type tp, FieldIndex,
623-
SparseTensorFieldKind kind,
624-
Level l, LevelType) {
621+
foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
622+
&stt](Type, FieldIndex,
623+
SparseTensorFieldKind kind,
624+
Level l, LevelType) {
625625
switch (kind) {
626626
case SparseTensorFieldKind::StorageSpec: {
627627
break;
@@ -632,8 +632,8 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
632632
rewriter.create<vector::PrintOp>(
633633
loc, lvl, vector::PrintPunctuation::NoPunctuation);
634634
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
635-
auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
636-
printContents(rewriter, loc, tp, pos);
635+
auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
636+
printContents(rewriter, loc, pos);
637637
break;
638638
}
639639
case SparseTensorFieldKind::CrdMemRef: {
@@ -642,15 +642,20 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
642642
rewriter.create<vector::PrintOp>(
643643
loc, lvl, vector::PrintPunctuation::NoPunctuation);
644644
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
645-
auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
646-
printContents(rewriter, loc, tp, crd);
645+
Value crd = nullptr;
646+
// TODO: eliminates ToCoordinateBufferOp!
647+
if (stt.getAoSCOOStart() == l)
648+
crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
649+
else
650+
crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
651+
printContents(rewriter, loc, crd);
647652
break;
648653
}
649654
case SparseTensorFieldKind::ValMemRef: {
650655
rewriter.create<vector::PrintOp>(loc,
651656
rewriter.getStringAttr("values : "));
652-
auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
653-
printContents(rewriter, loc, tp, val);
657+
auto val = rewriter.create<ToValuesOp>(loc, tensor);
658+
printContents(rewriter, loc, val);
654659
break;
655660
}
656661
}
@@ -670,7 +675,7 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
670675
//
671676
// Generates code to print:
672677
// ( a0, a1, ... )
673-
static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
678+
static void printContents(PatternRewriter &rewriter, Location loc,
674679
Value vec) {
675680
// Open bracket.
676681
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);

0 commit comments

Comments
 (0)