Skip to content

Commit 761c9dd

Browse files
authored
[mlir][sparse] implementing stageSparseOpPass as an interface (llvm#69022)
1 parent a22a1fe commit 761c9dd

File tree

11 files changed

+299
-197
lines changed

11 files changed

+299
-197
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ set(LLVM_TARGET_DEFINITIONS SparseTensorTypes.td)
1212
mlir_tablegen(SparseTensorTypes.h.inc -gen-typedef-decls)
1313
mlir_tablegen(SparseTensorTypes.cpp.inc -gen-typedef-defs)
1414
add_public_tablegen_target(MLIRSparseTensorTypesIncGen)
15+
16+
set(LLVM_TARGET_DEFINITIONS SparseTensorInterfaces.td)
17+
mlir_tablegen(SparseTensorInterfaces.h.inc -gen-op-interface-decls)
18+
mlir_tablegen(SparseTensorInterfaces.cpp.inc -gen-op-interface-defs)
19+
add_public_tablegen_target(MLIRSparseTensorInterfacesIncGen)
20+
add_dependencies(mlir-headers MLIRSparseTensorInterfacesIncGen)

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
14+
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
1415
#include "mlir/IR/BuiltinTypes.h"
1516
#include "mlir/IR/Dialect.h"
1617
#include "mlir/IR/OpDefinition.h"
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- SparseTensorInterfaces.h - sparse tensor operations
2+
//interfaces-------===//
3+
//
4+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
11+
#define MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
12+
13+
#include "mlir/IR/OpDefinition.h"
14+
15+
namespace mlir {
16+
class PatternRewriter;
17+
18+
namespace sparse_tensor {
19+
class StageWithSortSparseOp;
20+
21+
namespace detail {
22+
LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op,
23+
PatternRewriter &rewriter);
24+
} // namespace detail
25+
} // namespace sparse_tensor
26+
} // namespace mlir
27+
28+
/// Include the generated interface declarations.
29+
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h.inc"
30+
31+
#endif // MLIR_DIALECT_SPARSETENSOR_IR_SPARSETENSORINTERFACES_H_
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//===- SparseTensorInterfaces.td --------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SPARSETENSOR_IR_SPARSETENSORINTERFACES
10+
#define SPARSETENSOR_IR_SPARSETENSORINTERFACES
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> {
15+
let description = [{
16+
A stage-with-sort sparse tensor operation is an operation that produces
17+
unordered intermediate output. An extra sort is required to obtain the final
18+
ordered result.
19+
20+
E.g., convert csr -> csc need to be implemented as
21+
convert csr -> unordered coo -> sort by column -> csc; and
22+
concatenate csr, csc -> csr can be staged into
23+
concatenate csr, csr -> unordered coo -> sort by row -> csr.
24+
}];
25+
let cppNamespace = "::mlir::sparse_tensor";
26+
let methods = [
27+
InterfaceMethod<
28+
/*desc=*/"Return true if the operation needs an extra sort to produce the final result.",
29+
/*retTy=*/"bool",
30+
/*methodName=*/"needsExtraSort",
31+
/*args=*/(ins),
32+
/*methodBody=*/"">,
33+
InterfaceMethod<
34+
/*desc=*/"Stage the operation, return the final result value after staging.",
35+
/*retTy=*/"::mlir::LogicalResult",
36+
/*methodName=*/"stageWithSort",
37+
/*args=*/(ins "::mlir::PatternRewriter &":$rewriter),
38+
/*methodBody=*/[{
39+
return detail::stageWithSortImpl($_op, rewriter);
40+
}]>,
41+
];
42+
}
43+
44+
45+
#endif // SPARSETENSOR_IR_SPARSETENSORINTERFACES

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td"
1313
include "mlir/Dialect/SparseTensor/IR/SparseTensorBase.td"
1414
include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
15+
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1516
include "mlir/Interfaces/InferTypeOpInterface.td"
1617
include "mlir/Interfaces/SideEffectInterfaces.td"
1718

@@ -153,7 +154,7 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
153154
}
154155

155156
def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
156-
[Pure]>,
157+
[Pure, StageWithSortSparseOpInterface]>,
157158
Arguments<(ins AnyTensor:$source)>,
158159
Results<(outs AnyTensor:$dest)> {
159160
string summary = "Converts between different tensor types";
@@ -197,9 +198,9 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
197198
}];
198199

199200
let extraClassDeclaration = [{
200-
// Whether the convert can be done by a single step (either a sort or a foreach),
201-
// or it would require a tmp buffer (sort, then foreach).
202-
bool directConvertable();
201+
// Whether the convert can be done by a single step or it would require
202+
// an extra sort. Inherited from StageWithSortSparseOpInterface.
203+
bool needsExtraSort();
203204
}];
204205

205206
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
@@ -334,7 +335,8 @@ def SparseTensor_NumberOfEntriesOp : SparseTensor_Op<"number_of_entries", [Pure]
334335
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
335336
}
336337

337-
def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
338+
def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate",
339+
[Pure, StageWithSortSparseOpInterface]>,
338340
Arguments<(ins Variadic<AnyRankedTensor>:$inputs, DimensionAttr:$dimension)>,
339341
Results<(outs AnyRankedTensor:$result)> {
340342

@@ -357,6 +359,12 @@ def SparseTensor_ConcatenateOp : SparseTensor_Op<"concatenate", [Pure]>,
357359
```
358360
}];
359361

362+
let extraClassDeclaration = [{
363+
// Whether the concatenate can be done by a single step or it would require
364+
// an extra sort. Inherited from StageWithSortSparseOpInterface.
365+
bool needsExtraSort();
366+
}];
367+
360368
let assemblyFormat = "$inputs attr-dict `:` type($inputs) `to` type($result)";
361369
let hasVerifier = 1;
362370
}

mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ endif()
2929

3030
add_mlir_dialect_library(MLIRSparseTensorDialect
3131
SparseTensorDialect.cpp
32+
SparseTensorInterfaces.cpp
3233
Detail/Var.cpp
3334
Detail/DimLvlMap.cpp
3435
Detail/LvlTypeParser.cpp

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,18 +1065,18 @@ OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
10651065
return {};
10661066
}
10671067

1068-
bool ConvertOp::directConvertable() {
1068+
bool ConvertOp::needsExtraSort() {
10691069
SparseTensorType srcStt = getSparseTensorType(getSource());
10701070
SparseTensorType dstStt = getSparseTensorType(getDest());
10711071

1072-
// We can always directly convert to unordered sparse tensor or dense tensor
1073-
// since dense tensor support random access.
1072+
// We do not need an extra sort when returning unordered sparse tensors or
1073+
// dense tensor since dense tensor support random access.
10741074
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1075-
return true;
1075+
return false;
10761076

10771077
if (srcStt.isAllOrdered() && dstStt.isAllOrdered() &&
10781078
srcStt.hasSameDimToLvl(dstStt)) {
1079-
return true;
1079+
return false;
10801080
}
10811081

10821082
// Source and dest tensors are ordered in different ways. We only do direct
@@ -1086,9 +1086,9 @@ bool ConvertOp::directConvertable() {
10861086
// performance.
10871087
if (auto constOp = getSource().getDefiningOp<arith::ConstantOp>())
10881088
if (isa<SparseElementsAttr>(constOp.getValue()))
1089-
return true;
1089+
return false;
10901090

1091-
return false;
1091+
return true;
10921092
}
10931093

10941094
LogicalResult ToPositionsOp::verify() {
@@ -1248,6 +1248,23 @@ LogicalResult UnaryOp::verify() {
12481248
return success();
12491249
}
12501250

1251+
bool ConcatenateOp::needsExtraSort() {
1252+
SparseTensorType dstStt = getSparseTensorType(*this);
1253+
if (dstStt.isAllDense() || !dstStt.isAllOrdered())
1254+
return false;
1255+
1256+
bool allSameOrdered = llvm::all_of(getInputs(), [dstStt](Value op) {
1257+
return getSparseTensorType(op).hasSameDimToLvl(dstStt);
1258+
});
1259+
// TODO: When conDim != 0, as long as conDim corresponding to the first level
1260+
// in all input/output buffers, and all input/output buffers have the same
1261+
// dimToLvl, the tmp COO buffer is still unnecessary (e.g, concatenate
1262+
// CSC matrices along column).
1263+
bool directLowerable =
1264+
allSameOrdered && getDimension() == 0 && dstStt.isIdentity();
1265+
return !directLowerable;
1266+
}
1267+
12511268
LogicalResult ConcatenateOp::verify() {
12521269
const auto dstTp = getSparseTensorType(*this);
12531270
const Dimension concatDim = getDimension();
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- SparseTensorInterfaces.cpp - SparseTensor interfaces impl ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h"
10+
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
11+
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::sparse_tensor;
16+
17+
#include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc"
18+
19+
LogicalResult
20+
sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
21+
PatternRewriter &rewriter) {
22+
if (!op.needsExtraSort())
23+
return failure();
24+
25+
Location loc = op.getLoc();
26+
Type finalTp = op->getOpResult(0).getType();
27+
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
28+
29+
Type srcCOOTp = getCOOFromTypeWithOrdering(
30+
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
31+
32+
// Clones the original operation but changing the output to an unordered COO.
33+
Operation *cloned = rewriter.clone(*op.getOperation());
34+
rewriter.updateRootInPlace(cloned, [cloned, srcCOOTp]() {
35+
cloned->getOpResult(0).setType(srcCOOTp);
36+
});
37+
Value srcCOO = cloned->getOpResult(0);
38+
39+
// -> sort
40+
Type dstCOOTp = getCOOFromTypeWithOrdering(
41+
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
42+
Value dstCOO = rewriter.create<ReorderCOOOp>(
43+
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
44+
45+
// -> dest.
46+
if (dstCOO.getType() == finalTp) {
47+
rewriter.replaceOp(op, dstCOO);
48+
} else {
49+
// Need an extra conversion if the target type is not COO.
50+
rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO);
51+
}
52+
// TODO: deallocate extra COOs, we should probably delegate it to buffer
53+
// deallocation pass.
54+
return success();
55+
}

0 commit comments

Comments
 (0)