Skip to content

Commit 98f8b1a

Browse files
authored
[mlir][sparse] remove COO test from trait and encoding (llvm#73733)
This is a minor step towards moving ALL COO related tests into the SparseTensorType class rather than having it all over the place (with risk of becoming inconsistent). Next revision will move ALL COO related methods into this class.
1 parent 83305fa commit 98f8b1a

File tree

3 files changed

+13
-29
lines changed

3 files changed

+13
-29
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
382382
/// the null encoding (since dense-tensors are always all-dense).
383383
bool isAllDense() const;
384384

385-
/// Returns true if it is a sparse tensor encoding in COO format.
386-
bool isCOO() const;
387-
388385
/// Returns true if every level is ordered. Also returns true for
389386
/// the null encoding (since dense-tensors are always all-ordered).
390387
bool isAllOrdered() const;
@@ -467,33 +464,21 @@ def SparseTensorStorageSpecifierKindAttr
467464
def IsSparseTensorPred
468465
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self)">;
469466

470-
def IsCOOPred
471-
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
472-
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isCOO()">;
473-
474467
def IsSparseTensorSlicePred
475468
: CPred<"!!::mlir::sparse_tensor::getSparseTensorEncoding($_self) && "
476469
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
477470

478471
class SparseTensorOf<list<Type> allowedTypes>
479472
: TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
480473

481-
class COOSparseTensorOf<list<Type> allowedTypes>
482-
: TensorOf<allowedTypes, [IsCOOPred], "COO sparse tensor">;
483-
484474
class SparseTensorSliceOf<list<Type> allowedTypes>
485475
: TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
486476

487-
class RankedSparseTensorOf<list<Type> allowedTypes>
488-
: RankedTensorOf<allowedTypes, [IsSparseTensorPred], "ranked sparse tensor">;
489-
490477
class ScalarLikeOf<list<Type> allowedTypes>
491478
: AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>], "scalar like">;
492479

493480
def AnySparseTensor : SparseTensorOf<[AnyType]>;
494-
def AnyCOOSparseTensor : COOSparseTensorOf<[AnyType]>;
495481
def AnySparseTensorSlice : SparseTensorSliceOf<[AnyType]>;
496-
def AnyRankedSparseTensor : RankedSparseTensorOf<[AnyType]>;
497482
def AnyIndexingScalarLike : ScalarLikeOf<[AnySignlessIntegerOrIndex]>;
498483

499484
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -921,10 +921,9 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort">,
921921
let summary = "Sorts the arrays in xs and ys lexicographically on the "
922922
"integral values found in the xs list";
923923
let description = [{
924-
Sparse_tensor.sort sort the `xs` values along with some `ys` values
925-
that are put in a single linear buffer `xy`.
926-
The affine map attribute `perm_map` specifies the permutation to be applied on
927-
the `xs` before comparison, the rank of the permutation map
924+
Sorts the `xs` values along with some `ys` values that are put in a single linear
925+
buffer `xy`. The affine map attribute `perm_map` specifies the permutation to be
926+
applied on the `xs` before comparison, the rank of the permutation map
928927
also specifies the number of `xs` values in `xy`.
929928
The optional index attribute `ny` provides the number of `ys` values in `xy`.
930929
When `ny` is not explicitly specified, its value is 0.
@@ -950,14 +949,14 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort">,
950949
}
951950

952951
def SparseTensor_ReorderCOOOp : SparseTensor_Op<"reorder_coo", [Pure]>,
953-
Arguments<(ins AnyCOOSparseTensor: $input_coo,
952+
Arguments<(ins AnySparseTensor: $input_coo,
954953
SparseTensorSortKindAttr:$algorithm)>,
955-
Results<(outs AnyCOOSparseTensor: $result_coo)> {
954+
Results<(outs AnySparseTensor: $result_coo)> {
956955
let summary = "Reorder the input COO such that it has the the same order as "
957956
"the output COO";
958957
let description = [{
959-
sparse_tensor.reorder_coo reorder input COO to the same order as specified by
960-
the output format. E.g., reorder an unordered COO into an ordered one.
958+
Reorders the input COO to the same order as specified by the output format.
959+
E.g., reorder an unordered COO into an ordered one.
961960

962961
The input and result COO tensor must have the same element type, position type and
963962
coordinate type. At the moment, the operation also only supports ordering

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,6 @@ bool SparseTensorEncodingAttr::isAllDense() const {
316316
return !getImpl() || llvm::all_of(getLvlTypes(), isDenseLT);
317317
}
318318

319-
bool SparseTensorEncodingAttr::isCOO() const {
320-
return getImpl() && isCOOType(*this, 0, true);
321-
}
322-
323319
bool SparseTensorEncodingAttr::isAllOrdered() const {
324320
return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
325321
}
@@ -1664,14 +1660,18 @@ LogicalResult ReorderCOOOp::verify() {
16641660
SparseTensorType srcStt = getSparseTensorType(getInputCoo());
16651661
SparseTensorType dstStt = getSparseTensorType(getResultCoo());
16661662

1663+
if (!isCOOType(srcStt.getEncoding(), 0, /*isUnique=*/true) ||
1664+
!isCOOType(dstStt.getEncoding(), 0, /*isUnique=*/true))
1665+
emitError("Unexpected non-COO sparse tensors");
1666+
16671667
if (!srcStt.hasSameDimToLvl(dstStt))
16681668
emitError("Unmatched dim2lvl map between input and result COO");
16691669

16701670
if (srcStt.getPosType() != dstStt.getPosType() ||
16711671
srcStt.getCrdType() != dstStt.getCrdType() ||
1672-
srcStt.getElementType() != dstStt.getElementType()) {
1672+
srcStt.getElementType() != dstStt.getElementType())
16731673
emitError("Unmatched storage format between input and result COO");
1674-
}
1674+
16751675
return success();
16761676
}
16771677

0 commit comments

Comments
 (0)