Skip to content

[mlir][sparse] infer returned type for sparse_tensor.to_[buffer] ops #83343

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,10 @@ def SparseTensor_ReinterpretMapOp : SparseTensor_Op<"reinterpret_map", [NoMemory
let hasVerifier = 1;
}

def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts the `level`-th positions array of the `tensor`";
let description = [{
Returns the positions array of the tensor's storage at the given
Expand All @@ -283,9 +284,10 @@ def SparseTensor_ToPositionsOp : SparseTensor_Op<"positions", [Pure]>,
let hasVerifier = 1;
}

def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor, LevelAttr:$level)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts the `level`-th coordinates array of the `tensor`";
let description = [{
Returns the coordinates array of the tensor's storage at the given
Expand All @@ -309,9 +311,10 @@ def SparseTensor_ToCoordinatesOp : SparseTensor_Op<"coordinates", [Pure]>,
let hasVerifier = 1;
}

def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [Pure]>,
def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts the linear coordinates array from a tensor";
let description = [{
Returns the linear coordinates array for a sparse tensor with
Expand Down Expand Up @@ -340,9 +343,10 @@ def SparseTensor_ToCoordinatesBufferOp : SparseTensor_Op<"coordinates_buffer", [
let hasVerifier = 1;
}

def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [Pure]>,
def SparseTensor_ToValuesOp : SparseTensor_Op<"values",
[Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]>,
Arguments<(ins AnySparseTensor:$tensor)>,
Results<(outs AnyStridedMemRefOfRank<1>:$result)> {
Results<(outs AnyNon0RankedMemRef:$result)> {
let summary = "Extracts numerical values array from a tensor";
let description = [{
Returns the values array of the sparse storage format for the given
Expand Down
65 changes: 65 additions & 0 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,38 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
return {};
}

template <typename ToBufferOp>
static LogicalResult inferSparseBufferType(ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop,
RegionRange region,
SmallVectorImpl<mlir::Type> &ret) {
typename ToBufferOp::Adaptor adaptor(ops, attr, prop, region);
SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
Type elemTp = nullptr;
bool withStride = false;
if constexpr (std::is_same_v<ToBufferOp, ToPositionsOp>) {
elemTp = stt.getPosType();
} else if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp> ||
std::is_same_v<ToBufferOp, ToCoordinatesBufferOp>) {
elemTp = stt.getCrdType();
if constexpr (std::is_same_v<ToBufferOp, ToCoordinatesOp>)
withStride = stt.getAoSCOOStart() <= adaptor.getLevel();
} else if constexpr (std::is_same_v<ToBufferOp, ToValuesOp>) {
elemTp = stt.getElementType();
}

assert(elemTp && "unhandled operation.");
SmallVector<int64_t> bufShape = stt.getBatchLvlShape();
bufShape.push_back(ShapedType::kDynamic);

auto layout = withStride ? StridedLayoutAttr::StridedLayoutAttr::get(
stt.getContext(), ShapedType::kDynamic,
{ShapedType::kDynamic})
: StridedLayoutAttr();
ret.emplace_back(MemRefType::get(bufShape, elemTp, layout));
return success();
}

LogicalResult ToPositionsOp::verify() {
auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
Expand All @@ -1454,6 +1486,14 @@ LogicalResult ToPositionsOp::verify() {
return success();
}

LogicalResult
ToPositionsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) {
return inferSparseBufferType<ToPositionsOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesOp::verify() {
auto stt = getSparseTensorType(getTensor());
if (failed(lvlIsInBounds(getLevel(), getTensor())))
Expand All @@ -1463,13 +1503,29 @@ LogicalResult ToCoordinatesOp::verify() {
return success();
}

LogicalResult
ToCoordinatesOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) {
return inferSparseBufferType<ToCoordinatesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToCoordinatesBufferOp::verify() {
auto stt = getSparseTensorType(getTensor());
if (stt.getAoSCOOStart() >= stt.getLvlRank())
return emitError("expected sparse tensor with a COO region");
return success();
}

LogicalResult ToCoordinatesBufferOp::inferReturnTypes(
MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
SmallVectorImpl<mlir::Type> &ret) {
return inferSparseBufferType<ToCoordinatesBufferOp>(ops, attr, prop, region,
ret);
}

LogicalResult ToValuesOp::verify() {
auto stt = getSparseTensorType(getTensor());
auto mtp = getMemRefType(getResult());
Expand All @@ -1478,6 +1534,15 @@ LogicalResult ToValuesOp::verify() {
return success();
}

LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
std::optional<Location> loc,
ValueRange ops, DictionaryAttr attr,
OpaqueProperties prop,
RegionRange region,
SmallVectorImpl<mlir::Type> &ret) {
return inferSparseBufferType<ToValuesOp>(ops, attr, prop, region, ret);
}

LogicalResult ToSliceOffsetOp::verify() {
auto rank = getRankedTensorType(getSlice()).getRank();
if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1058,17 +1058,9 @@ class SparseToCoordinatesConverter
// Replace the requested coordinates access with corresponding field.
// The cast_op is inserted by type converter to intermix 1:N type
// conversion.
Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
Value field = desc.getCrdMemRefOrView(rewriter, loc, op.getLevel());

// Insert a cast to bridge the actual type to the user expected type. If the
// actual type and the user expected type aren't compatible, the compiler or
// the runtime will issue an error.
Type resType = op.getResult().getType();
if (resType != field.getType())
field = rewriter.create<memref::CastOp>(loc, resType, field);
rewriter.replaceOp(op, field);
rewriter.replaceOp(
op, desc.getCrdMemRefOrView(rewriter, op.getLoc(), op.getLevel()));

return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,10 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
rewriter.create<vector::PrintOp>(loc, nse);
// Use the "codegen" foreach loop construct to iterate over
// all typical sparse tensor components for printing.
foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc,
&tensor](Type tp, FieldIndex,
SparseTensorFieldKind kind,
Level l, LevelType) {
foreachFieldAndTypeInSparseTensor(stt, [&rewriter, &loc, &tensor,
&stt](Type, FieldIndex,
SparseTensorFieldKind kind,
Level l, LevelType) {
switch (kind) {
case SparseTensorFieldKind::StorageSpec: {
break;
Expand All @@ -632,8 +632,8 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
auto pos = rewriter.create<ToPositionsOp>(loc, tp, tensor, l);
printContents(rewriter, loc, tp, pos);
auto pos = rewriter.create<ToPositionsOp>(loc, tensor, l);
printContents(rewriter, loc, pos);
break;
}
case SparseTensorFieldKind::CrdMemRef: {
Expand All @@ -642,15 +642,20 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
rewriter.create<vector::PrintOp>(
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
auto crd = rewriter.create<ToCoordinatesOp>(loc, tp, tensor, l);
printContents(rewriter, loc, tp, crd);
Value crd = nullptr;
// TODO: eliminates ToCoordinateBufferOp!
if (stt.getAoSCOOStart() == l)
crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
else
crd = rewriter.create<ToCoordinatesOp>(loc, tensor, l);
printContents(rewriter, loc, crd);
break;
}
case SparseTensorFieldKind::ValMemRef: {
rewriter.create<vector::PrintOp>(loc,
rewriter.getStringAttr("values : "));
auto val = rewriter.create<ToValuesOp>(loc, tp, tensor);
printContents(rewriter, loc, tp, val);
auto val = rewriter.create<ToValuesOp>(loc, tensor);
printContents(rewriter, loc, val);
break;
}
}
Expand All @@ -670,7 +675,7 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
//
// Generates code to print:
// ( a0, a1, ... )
static void printContents(PatternRewriter &rewriter, Location loc, Type tp,
static void printContents(PatternRewriter &rewriter, Location loc,
Value vec) {
// Open bracket.
rewriter.create<vector::PrintOp>(loc, vector::PrintPunctuation::Open);
Expand Down
Loading