Commit 77f8297
[mlir][sparse] Improve sparse tensor type constraints (#112133)
Sparse tensors are always ranked tensors; an encoding cannot be attached to an unranked tensor. Change the type constraint to `RankedTensorOf`, so that ODS generates `TypedValue<RankedTensorType>` instead of `TypedValue<TensorType>`. This removes the need for type casting in some cases. Also fix verifiers that called `emitError` without returning the resulting failure, and switch a few other `AnyTensor` constraints to `AnyRankedTensor`. This commit is in preparation for a dialect conversion commit that requires these fixes in the sparse dialect.
1 parent 2c5dd03 commit 77f8297
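To illustrate the main point of the change, here is a minimal sketch of the accessor types ODS produces for the two constraints. The signatures are paraphrased from the commit message, not copied from generated code; `getValues()` is the accessor for the `$values` operand changed below:

```cpp
// With `TensorOf<[AnyType]>:$values`, ODS generates an accessor like
//   ::mlir::TypedValue<::mlir::TensorType> getValues();
// so ranked-only queries need an explicit cast first:
auto rankedTp = llvm::cast<mlir::RankedTensorType>(getValues().getType());

// With `RankedTensorOf<[AnyType]>:$values`, the accessor becomes
//   ::mlir::TypedValue<::mlir::RankedTensorType> getValues();
// and the ranked type is available directly, no cast needed:
mlir::RankedTensorType valuesTp = getValues().getType();
```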

4 files changed (+64, -59 lines)

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 2 additions & 2 deletions
@@ -586,10 +586,10 @@ def IsSparseTensorSlicePred
           " ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;
 
 class SparseTensorOf<list<Type> allowedTypes>
-    : TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
+    : RankedTensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
 
 class SparseTensorSliceOf<list<Type> allowedTypes>
-    : TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
+    : RankedTensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
 
 class ScalarLikeOf<list<Type> allowedTypes>
     : AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>], "scalar like">;

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 11 additions & 11 deletions
@@ -92,8 +92,8 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]> {
     ```
   }];
 
-  let arguments = (ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
-                       TensorOf<[AnyType]>:$values);
+  let arguments = (ins Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                       RankedTensorOf<[AnyType]>:$values);
   let results = (outs AnySparseTensor: $result);
   let assemblyFormat =
     "` ` `(` $levels `)` `,` $values attr-dict `:`"
@@ -138,12 +138,12 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
   }];
 
   let arguments = (ins AnySparseTensor:$tensor,
-                       Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
-                       TensorOf<[AnyType]>:$out_values);
-  let results = (outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                      TensorOf<[AnyType]>:$ret_values,
-                      Variadic<AnyIndexingScalarLike>:$lvl_lens,
-                      AnyIndexingScalarLike:$val_len);
+                       Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                       RankedTensorOf<[AnyType]>:$out_values);
+  let results = (outs Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                      RankedTensorOf<[AnyType]>:$ret_values,
+                      Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                      AnyIndexingScalarLike:$val_len);
   let assemblyFormat =
     "$tensor attr-dict `:` type($tensor)"
     "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
@@ -196,8 +196,8 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
 
   }];
 
-  let arguments = (ins AnyTensor:$source);
-  let results = (outs AnyTensor:$dest);
+  let arguments = (ins AnyRankedTensor:$source);
+  let results = (outs AnyRankedTensor:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 
   let extraClassDeclaration = [{
@@ -1447,7 +1447,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   ];
 
   let regions = (region SizedRegion<1>:$region);
-  let arguments = (ins AnyTensor:$tensor,
+  let arguments = (ins AnyRankedTensor:$tensor,
                        Variadic<AnyType>:$initArgs,
                        OptionalAttr<AffineMapAttr>:$order);
   let results = (outs Variadic<AnyType>:$results);

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 48 additions & 43 deletions
@@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
       // The coordinates should be in shape of <? x rank>
       unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
       if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
-        op->emitError("input/output trailing COO level-ranks don't match");
+        return op->emitError("input/output trailing COO level-ranks don't match");
       }
     }
 
@@ -1350,7 +1350,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
 }
 
 LogicalResult AssembleOp::verify() {
-  const auto valuesTp = getRankedTensorType(getValues());
+  RankedTensorType valuesTp = getValues().getType();
   const auto lvlsTp = getLevels().getTypes();
   const auto resTp = getSparseTensorType(getResult());
   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
@@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() {
   if (ot.getType() != rt.getType())
     return emitError("output levels and return levels type mismatch");
 
-  const auto valuesTp = getRankedTensorType(getRetValues());
+  RankedTensorType valuesTp = getRetValues().getType();
   const auto lvlsTp = getRetLevels().getTypes();
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
 }
 
 LogicalResult ConvertOp::verify() {
-  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
-    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
-      if (tp1.getRank() != tp2.getRank())
-        return emitError("unexpected conversion mismatch in rank");
-      auto dstEnc =
-          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
-      if (dstEnc && dstEnc.isSlice())
-        return emitError("cannot convert to a sparse tensor slice");
-
-      auto shape1 = tp1.getShape();
-      auto shape2 = tp2.getShape();
-      // Accept size matches between the source and the destination type
-      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
-      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
-      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
-        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
-          return emitError("unexpected conversion mismatch in dimension ") << d;
-      return success();
-    }
-  }
-  return emitError("unexpected type in convert");
+  RankedTensorType tp1 = getSource().getType();
+  RankedTensorType tp2 = getDest().getType();
+  if (tp1.getRank() != tp2.getRank())
+    return emitError("unexpected conversion mismatch in rank");
+  auto dstEnc =
+      llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
+  if (dstEnc && dstEnc.isSlice())
+    return emitError("cannot convert to a sparse tensor slice");
+
+  auto shape1 = tp1.getShape();
+  auto shape2 = tp2.getShape();
+  // Accept size matches between the source and the destination type
+  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
+  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
+    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
+      return emitError("unexpected conversion mismatch in dimension ") << d;
+  return success();
 }
 
 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
@@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() {
   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
     auto stt = getSparseTensorType(getSource());
     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
-      emitError("Level index exceeds the rank of the input sparse tensor");
+      return emitError(
+          "Level index exceeds the rank of the input sparse tensor");
   }
   return success();
 }
@@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
 }
 
 LogicalResult ToSliceOffsetOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
 }
 
 LogicalResult ToSliceStrideOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
@@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() {
   const auto iTp = IndexType::get(getContext());
   for (Dimension d = 0; d < dimRank; d++)
     if (args[d].getType() != iTp)
-      emitError(
+      return emitError(
           llvm::formatv("Expecting Index type for argument at index {0}", d));
 
   const auto elemTp = t.getElementType();
   const auto valueTp = args[dimRank].getType();
   if (elemTp != valueTp)
-    emitError(llvm::formatv("Unmatched element type between input tensor and "
-                            "block argument, expected:{0}, got: {1}",
-                            elemTp, valueTp));
+    return emitError(
+        llvm::formatv("Unmatched element type between input tensor and "
+                      "block argument, expected:{0}, got: {1}",
+                      elemTp, valueTp));
   return success();
 }
 
@@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() {
   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
 
   if (!srcStt.isCOOType() || !dstStt.isCOOType())
-    emitError("Expected COO sparse tensors only");
+    return emitError("Expected COO sparse tensors only");
 
   if (!srcStt.hasSameDimToLvl(dstStt))
-    emitError("Unmatched dim2lvl map between input and result COO");
+    return emitError("Unmatched dim2lvl map between input and result COO");
 
   if (srcStt.getPosType() != dstStt.getPosType() ||
       srcStt.getCrdType() != dstStt.getCrdType() ||
       srcStt.getElementType() != dstStt.getElementType())
-    emitError("Unmatched storage format between input and result COO");
+    return emitError("Unmatched storage format between input and result COO");
 
   return success();
 }
@@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() {
   AffineMap xPerm = getPermMap();
   uint64_t nx = xPerm.getNumDims();
   if (nx < 1)
-    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+    return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
 
   if (!xPerm.isPermutation())
-    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+    return emitError(
+        llvm::formatv("Expected a permutation map, got {0}", xPerm));
 
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -2056,19 +2056,24 @@ LogicalResult SortOp::verify() {
     return success();
 
   // Verify dimensions.
-  const auto checkDim = [&](Value v, Size minSize, const char *message) {
+  const auto checkDim = [&](Value v, Size minSize,
+                            const char *message) -> LogicalResult {
     const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
-      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+      return emitError(
+          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+    return success();
   };
   uint64_t n = cn.value();
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr())
     ny = nyAttr.getInt();
-  checkDim(getXy(), n * (nx + ny),
-           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
+  if (failed(checkDim(getXy(), n * (nx + ny),
+                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
+    return failure();
   for (Value opnd : getYs())
-    checkDim(opnd, n, "Expected dimension(y) >= n");
+    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
+      return failure();
 
   return success();
 }
@@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
   }
 
   if (lvlHi <= lvlLo)
-    parser.emitError(parser.getNameLoc(),
-                     "expect larger level upper bound than lower bound");
+    return parser.emitError(parser.getNameLoc(),
+                            "expect larger level upper bound than lower bound");
 
   return success();
 }
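Several of the verifier fixes above follow the same pattern: `emitError` was called without returning its result. A minimal sketch of why the added `return` matters (`MyOp` and `badCondition` are hypothetical placeholders; the conversion of `InFlightDiagnostic` to a failed `LogicalResult` is standard MLIR behavior):

```cpp
mlir::LogicalResult MyOp::verify() {
  if (badCondition) {
    // Buggy form (what the patch removes):
    //   emitError("invalid op");
    // The diagnostic is printed, but control falls through to success(),
    // so the invalid op still passes verification.
    //
    // Fixed form: emitError() returns an InFlightDiagnostic, which converts
    // to a failed LogicalResult, so returning it propagates the failure.
    return emitError("invalid op");
  }
  return mlir::success();
}
```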

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 3 additions & 3 deletions
@@ -105,7 +105,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {
 
 func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
 
@@ -141,7 +141,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {
 
 func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.coordinates %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }
 
@@ -347,7 +347,7 @@ func.func @sparse_wrong_arity_compression(%arg0: memref<?xf64>,
 // -----
 
 func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
-  // expected-error@+1 {{unexpected type in convert}}
+  // expected-error@+1 {{invalid kind of type specified}}
   %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
   return %0 : tensor<10xf32>
 }
