Skip to content

Commit 0aacc21

Browse files
authored
[mlir][sparse] introduce sparse_tensor.reorder_coo operation (#68827)
1 parent cff5007 commit 0aacc21

File tree

6 files changed

+121
-4
lines changed

6 files changed

+121
-4
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
134134
level-coordinates. The dimension-expressions collectively define the inverse map,
135135
which only needs to be provided for elaborate cases where it cannot be inferred
136136
automatically.
137-
137+
138138
Each dimension could also have an optional `SparseTensorDimSliceAttr`.
139139
Within the sparse storage format, we refer to indices that are stored explicitly
140140
as **coordinates** and offsets into the storage format as **positions**.
@@ -237,10 +237,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
237237
}>
238238
... tensor<20x30xf32, #BSR_explicit> ...
239239

240-
// ELL format.
240+
// ELL format.
241241
// In the simple format for matrix, one array stores values and another
242242
// array stores column indices. The arrays have the same number of rows
243-
// as the original matrix, but only have as many columns as
243+
// as the original matrix, but only have as many columns as
244244
// the maximum number of nonzeros on a row of the original matrix.
245245
// There are many variants for ELL such as jagged diagonal scheme.
246246
// To implement ELL, map provides a notion of "counting a
@@ -376,6 +376,9 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
376376
/// the null encoding (since dense-tensors are always all-dense).
377377
bool isAllDense() const;
378378

379+
/// Returns true if it is a sparse tensor encoding in COO format.
380+
bool isCOO() const;
381+
379382
/// Returns true if every level is ordered. Also returns true for
380383
/// the null encoding (since dense-tensors are always all-ordered).
381384
bool isAllOrdered() const;
@@ -468,6 +471,10 @@ def SparseTensorStorageSpecifierKindAttr
468471
def IsSparseTensorPred
469472
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
470473

474+
def IsCOOPred
475+
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
476+
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isCOO()">;
477+
471478
def IsSparseTensorSlicePred
472479
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
473480
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
@@ -478,10 +485,14 @@ def IsSparseTensorSlicePred
478485
class SparseTensorOf<list<Type> allowedTypes>
479486
: TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
480487

488+
class COOSparseTensorOf<list<Type> allowedTypes>
489+
: TensorOf<allowedTypes, [IsCOOPred], "COO sparse tensor">;
490+
481491
class SparseTensorSliceOf<list<Type> allowedTypes>
482492
: TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
483493

484494
def AnySparseTensor : SparseTensorOf<[AnyType]>;
495+
def AnyCOOSparseTensor : COOSparseTensorOf<[AnyType]>;
485496
def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>;
486497

487498
class RankedSparseTensorOf<list<Type> allowedTypes>

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
770770
}
771771

772772
//===----------------------------------------------------------------------===//
773-
// Sparse Tensor Sorting Operations.
773+
// Sparse Tensor Sorting/Ordering Operations.
774774
//===----------------------------------------------------------------------===//
775775

776776
def SparseTensor_SortOp : SparseTensor_Op<"sort">,
@@ -809,6 +809,36 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort">,
809809
let hasVerifier = 1;
810810
}
811811

812+
def SparseTensor_ReorderCOOOp : SparseTensor_Op<"reorder_coo", [Pure]>,
813+
Arguments<(ins AnyCOOSparseTensor: $input_coo,
814+
SparseTensorSortKindAttr:$algorithm)>,
815+
Results<(outs AnyCOOSparseTensor: $result_coo)> {
816+
let summary = "Reorder the input COO such that it has the the same order as "
817+
"the output COO";
818+
let description = [{
819+
sparse_tensor.reorder_coo reorder input COO to the same order as specified by
820+
the output format. E.g., reorder an unordered COO into an ordered one.
821+
822+
The input and result COO tensor must have the same element type, position type and
823+
coordinate type. At the moment, the operation also only supports ordering
824+
input and result COO with the same dim2lvl map.
825+
826+
Example:
827+
828+
```mlir
829+
%res = sparse_tensor.reorder_coo quick_sort %coo : tensor<?x?xf64 : #Unordered_COO> to
830+
tensor<?x?xf64 : #Ordered_COO>
831+
832+
```
833+
}];
834+
835+
let assemblyFormat = "$algorithm $input_coo attr-dict"
836+
"`:` type($input_coo) `to` type($result_coo)";
837+
838+
let hasFolder = 1;
839+
let hasVerifier = 1;
840+
}
841+
812842
//===----------------------------------------------------------------------===//
813843
// Sparse Tensor Syntax Operations.
814844
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ bool SparseTensorEncodingAttr::isAllDense() const {
336336
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseDLT);
337337
}
338338

339+
bool SparseTensorEncodingAttr::isCOO() const {
340+
return getImpl() && isCOOType(*this, 0, true);
341+
}
342+
339343
bool SparseTensorEncodingAttr::isAllOrdered() const {
340344
return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedDLT);
341345
}
@@ -1417,6 +1421,29 @@ LogicalResult ForeachOp::verify() {
14171421
return success();
14181422
}
14191423

1424+
OpFoldResult ReorderCOOOp::fold(FoldAdaptor adaptor) {
1425+
if (getSparseTensorEncoding(getInputCoo().getType()) ==
1426+
getSparseTensorEncoding(getResultCoo().getType()))
1427+
return getInputCoo();
1428+
1429+
return {};
1430+
}
1431+
1432+
LogicalResult ReorderCOOOp::verify() {
1433+
SparseTensorType srcStt = getSparseTensorType(getInputCoo());
1434+
SparseTensorType dstStt = getSparseTensorType(getResultCoo());
1435+
1436+
if (!srcStt.hasSameDimToLvl(dstStt))
1437+
emitError("Unmatched dim2lvl map between input and result COO");
1438+
1439+
if (srcStt.getPosType() != dstStt.getPosType() ||
1440+
srcStt.getCrdType() != dstStt.getCrdType() ||
1441+
srcStt.getElementType() != dstStt.getElementType()) {
1442+
emitError("Unmatched storage format between input and result COO");
1443+
}
1444+
return success();
1445+
}
1446+
14201447
LogicalResult ReduceOp::verify() {
14211448
Type inputType = getX().getType();
14221449
// Check correct number of block arguments and return type.

mlir/test/Dialect/SparseTensor/fold.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,16 @@ func.func @sparse_get_specifier_dce_fold(%arg0: !sparse_tensor.storage_specifier
6262
: !sparse_tensor.storage_specifier<#SparseVector>
6363
return %2 : index
6464
}
65+
66+
67+
68+
#COO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)}>
69+
70+
// CHECK-LABEL: func @sparse_reorder_coo(
71+
// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
72+
// CHECK-NOT: %[[R:.*]] = sparse_tensor.reorder_coo
73+
// CHECK: return %[[A]]
74+
func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #COO>) -> tensor<?x?xf32, #COO> {
75+
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #COO> to tensor<?x?xf32, #COO>
76+
return %ret : tensor<?x?xf32, #COO>
77+
}

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,3 +839,25 @@ func.func @sparse_alloc_escapes(%arg0: index) -> tensor<10x?xf64, #CSR> {
839839
%0 = bufferization.alloc_tensor(%arg0) : tensor<10x?xf64, #CSR>
840840
return %0: tensor<10x?xf64, #CSR>
841841
}
842+
843+
// -----
844+
845+
#UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
846+
#OrderedCOOPerm = #sparse_tensor.encoding<{map = (d0, d1) -> (d1 : compressed(nonunique), d0 : singleton)}>
847+
848+
func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<?x?xf32, #OrderedCOOPerm> {
849+
// expected-error@+1 {{Unmatched dim2lvl map between input and result COO}}
850+
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOOPerm>
851+
return %ret : tensor<?x?xf32, #OrderedCOOPerm>
852+
}
853+
854+
// -----
855+
856+
#UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
857+
#OrderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)}>
858+
859+
func.func @sparse_permuted_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<?x?xf64, #OrderedCOO> {
860+
// expected-error@+1 {{Unmatched storage format between input and result COO}}
861+
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf64, #OrderedCOO>
862+
return %ret : tensor<?x?xf64, #OrderedCOO>
863+
}

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,3 +633,17 @@ func.func @sparse_sort_coo_stable(%arg0: index, %arg1: memref<?xi64>, %arg2: mem
633633
sparse_tensor.sort insertion_sort_stable %arg0, %arg1 jointly %arg2 {perm_map = #ID_MAP, ny = 1 : index}: memref<?xi64> jointly memref<?xf32>
634634
return %arg1, %arg2 : memref<?xi64>, memref<?xf32>
635635
}
636+
637+
// -----
638+
639+
#UnorderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique, nonordered), d1 : singleton(nonordered))}>
640+
#OrderedCOO = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)}>
641+
642+
// CHECK-LABEL: func @sparse_reorder_coo(
643+
// CHECK-SAME: %[[A:.*]]: tensor<?x?xf32, #sparse_tensor.encoding<{{{.*}}}>>
644+
// CHECK: %[[R:.*]] = sparse_tensor.reorder_coo quick_sort %[[A]]
645+
// CHECK: return %[[R]]
646+
func.func @sparse_reorder_coo(%arg0 : tensor<?x?xf32, #UnorderedCOO>) -> tensor<?x?xf32, #OrderedCOO> {
647+
%ret = sparse_tensor.reorder_coo quick_sort %arg0 : tensor<?x?xf32, #UnorderedCOO> to tensor<?x?xf32, #OrderedCOO>
648+
return %ret : tensor<?x?xf32, #OrderedCOO>
649+
}

0 commit comments

Comments
 (0)