[mlir][sparse] Improve sparse tensor type constraints #112133

Merged · 1 commit · Oct 13, 2024
@@ -586,10 +586,10 @@ def IsSparseTensorSlicePred
" ::mlir::sparse_tensor::getSparseTensorEncoding($_self).isSlice()">;

class SparseTensorOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;
: RankedTensorOf<allowedTypes, [IsSparseTensorPred], "sparse tensor">;

class SparseTensorSliceOf<list<Type> allowedTypes>
: TensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;
: RankedTensorOf<allowedTypes, [IsSparseTensorSlicePred], "sparse tensor slice">;

class ScalarLikeOf<list<Type> allowedTypes>
: AnyTypeOf<[0DTensorOf<allowedTypes>, AnyTypeOf<allowedTypes>], "scalar like">;
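For context: `TensorOf` also admits unranked tensors such as `tensor<*xf64>`, whereas `RankedTensorOf` rejects them in the ODS-generated verifier. A minimal sketch of the new behavior (the `#SparseVector` encoding here is illustrative, not part of the patch):

```mlir
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

func.func @ranked_ok(%arg0: tensor<?xf64, #SparseVector>) -> memref<?xindex> {
  // Accepted: the operand is a ranked sparse tensor.
  %0 = sparse_tensor.positions %arg0 { level = 0 : index }
     : tensor<?xf64, #SparseVector> to memref<?xindex>
  return %0 : memref<?xindex>
}

// With an unranked operand, the ODS constraint now fires, e.g.:
//   'sparse_tensor.positions' op operand #0 must be sparse tensor of
//   any type values, but got 'tensor<*xf64>'
```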
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (11 additions, 11 deletions)
@@ -92,8 +92,8 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]> {
   ```
   }];
 
-  let arguments = (ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
-                       TensorOf<[AnyType]>:$values);
+  let arguments = (ins Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                       RankedTensorOf<[AnyType]>:$values);
   let results = (outs AnySparseTensor: $result);
   let assemblyFormat =
     "` ` `(` $levels `)` `,` $values attr-dict `:`"
@@ -138,12 +138,12 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
   }];
 
   let arguments = (ins AnySparseTensor:$tensor,
-                      Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
-                      TensorOf<[AnyType]>:$out_values);
-  let results = (outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                      TensorOf<[AnyType]>:$ret_values,
-                      Variadic<AnyIndexingScalarLike>:$lvl_lens,
-                      AnyIndexingScalarLike:$val_len);
+                      Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                      RankedTensorOf<[AnyType]>:$out_values);
+  let results = (outs Variadic<RankedTensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                      RankedTensorOf<[AnyType]>:$ret_values,
+                      Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                      AnyIndexingScalarLike:$val_len);
   let assemblyFormat =
     "$tensor attr-dict `:` type($tensor)"
     "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
@@ -196,8 +196,8 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",

   }];
 
-  let arguments = (ins AnyTensor:$source);
-  let results = (outs AnyTensor:$dest);
+  let arguments = (ins AnyRankedTensor:$source);
+  let results = (outs AnyRankedTensor:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 
   let extraClassDeclaration = [{
@@ -1447,7 +1447,7 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
   ];
 
   let regions = (region SizedRegion<1>:$region);
-  let arguments = (ins AnyTensor:$tensor,
+  let arguments = (ins AnyRankedTensor:$tensor,
                        Variadic<AnyType>:$initArgs,
                        OptionalAttr<AffineMapAttr>:$order);
   let results = (outs Variadic<AnyType>:$results);
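Since `ConvertOp` now declares `AnyRankedTensor` operands and results, a conversion involving an unranked tensor is caught by the generated type constraints and never reaches the custom verifier, so the old "unexpected type in convert" diagnostic becomes dead code. A sketch, mirroring the updated `invalid.mlir` test below:

```mlir
func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
  // Now rejected up front with "invalid kind of type specified".
  %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
  return %0 : tensor<10xf32>
}
```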
mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (48 additions, 43 deletions)
@@ -1310,7 +1310,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
       // The coordinates should be in shape of <? x rank>
       unsigned expCOORank = stt.getLvlRank() - cooStartLvl;
       if (cooTp.getRank() != 2 || expCOORank != cooTp.getShape().back()) {
-        op->emitError("input/output trailing COO level-ranks don't match");
+        return op->emitError("input/output trailing COO level-ranks don't match");
       }
     }

@@ -1350,7 +1350,7 @@
 }
 
 LogicalResult AssembleOp::verify() {
-  const auto valuesTp = getRankedTensorType(getValues());
+  RankedTensorType valuesTp = getValues().getType();
   const auto lvlsTp = getLevels().getTypes();
   const auto resTp = getSparseTensorType(getResult());
   return verifyPackUnPack(*this, true, resTp, valuesTp, lvlsTp);
@@ -1364,34 +1364,31 @@ LogicalResult DisassembleOp::verify() {
   if (ot.getType() != rt.getType())
     return emitError("output levels and return levels type mismatch");
 
-  const auto valuesTp = getRankedTensorType(getRetValues());
+  RankedTensorType valuesTp = getRetValues().getType();
   const auto lvlsTp = getRetLevels().getTypes();
   const auto srcTp = getSparseTensorType(getTensor());
   return verifyPackUnPack(*this, false, srcTp, valuesTp, lvlsTp);
 }

 LogicalResult ConvertOp::verify() {
-  if (auto tp1 = llvm::dyn_cast<RankedTensorType>(getSource().getType())) {
-    if (auto tp2 = llvm::dyn_cast<RankedTensorType>(getDest().getType())) {
-      if (tp1.getRank() != tp2.getRank())
-        return emitError("unexpected conversion mismatch in rank");
-      auto dstEnc =
-          llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
-      if (dstEnc && dstEnc.isSlice())
-        return emitError("cannot convert to a sparse tensor slice");
-
-      auto shape1 = tp1.getShape();
-      auto shape2 = tp2.getShape();
-      // Accept size matches between the source and the destination type
-      // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
-      // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
-      for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
-        if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
-          return emitError("unexpected conversion mismatch in dimension ") << d;
-      return success();
-    }
-  }
-  return emitError("unexpected type in convert");
+  RankedTensorType tp1 = getSource().getType();
+  RankedTensorType tp2 = getDest().getType();
+  if (tp1.getRank() != tp2.getRank())
+    return emitError("unexpected conversion mismatch in rank");
+  auto dstEnc =
+      llvm::dyn_cast_or_null<SparseTensorEncodingAttr>(tp2.getEncoding());
+  if (dstEnc && dstEnc.isSlice())
+    return emitError("cannot convert to a sparse tensor slice");
+
+  auto shape1 = tp1.getShape();
+  auto shape2 = tp2.getShape();
+  // Accept size matches between the source and the destination type
+  // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
+  // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
+  for (Dimension d = 0, dimRank = tp1.getRank(); d < dimRank; d++)
+    if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamic)
+      return emitError("unexpected conversion mismatch in dimension ") << d;
+  return success();
 }
 
 OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) {
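The dimension loop above implements the comment literally: a static destination size must match the source exactly, while a dynamic destination size accepts anything. A hedged IR-level sketch (the `#SV` encoding is assumed for illustration):

```mlir
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

func.func @convert_shapes(%a: tensor<10xf64>, %b: tensor<?xf64>) {
  // Accepted: 10 vs. 10.
  %0 = sparse_tensor.convert %a : tensor<10xf64> to tensor<10xf64, #SV>
  // Accepted: 10 vs. ? (dynamic destination).
  %1 = sparse_tensor.convert %a : tensor<10xf64> to tensor<?xf64, #SV>
  // Rejected: ? vs. 10 would need a runtime assert.
  // %2 = sparse_tensor.convert %b : tensor<?xf64> to tensor<10xf64, #SV>
  return
}
```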
@@ -1495,7 +1492,8 @@ LogicalResult LvlOp::verify() {
   if (std::optional<uint64_t> lvl = getConstantLvlIndex()) {
     auto stt = getSparseTensorType(getSource());
     if (static_cast<uint64_t>(lvl.value()) >= stt.getLvlRank())
-      emitError("Level index exceeds the rank of the input sparse tensor");
+      return emitError(
+          "Level index exceeds the rank of the input sparse tensor");
   }
   return success();
 }
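Previously the diagnostic was emitted but the missing `return` let `verify()` fall through to `success()`, so the invalid IR slipped past the verifier; now the failure propagates. A hypothetical example that should now be rejected (encoding and syntax sketched from the op's documented form, not taken from the patch):

```mlir
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

func.func @lvl_out_of_bounds(%t: tensor<?xf64, #SparseVector>) -> index {
  %c1 = arith.constant 1 : index
  // Level index 1 exceeds the level rank (1) of the input.
  %l = sparse_tensor.lvl %t, %c1 : tensor<?xf64, #SparseVector>
  return %l : index
}
```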
@@ -1697,14 +1695,14 @@ LogicalResult ToValuesOp::inferReturnTypes(MLIRContext *ctx,
 }
 
 LogicalResult ToSliceOffsetOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
 }
 
 LogicalResult ToSliceStrideOp::verify() {
-  auto rank = getRankedTensorType(getSlice()).getRank();
+  auto rank = getSlice().getType().getRank();
   if (rank <= getDim().getSExtValue() || getDim().getSExtValue() < 0)
     return emitError("requested dimension out of bound");
   return success();
@@ -1986,15 +1984,16 @@ LogicalResult ForeachOp::verify() {
   const auto iTp = IndexType::get(getContext());
   for (Dimension d = 0; d < dimRank; d++)
     if (args[d].getType() != iTp)
-      emitError(
+      return emitError(
           llvm::formatv("Expecting Index type for argument at index {0}", d));
 
   const auto elemTp = t.getElementType();
   const auto valueTp = args[dimRank].getType();
   if (elemTp != valueTp)
-    emitError(llvm::formatv("Unmatched element type between input tensor and "
-                            "block argument, expected:{0}, got: {1}",
-                            elemTp, valueTp));
+    return emitError(
+        llvm::formatv("Unmatched element type between input tensor and "
+                      "block argument, expected:{0}, got: {1}",
+                      elemTp, valueTp));
   return success();
 }

@@ -2011,15 +2010,15 @@ LogicalResult ReorderCOOOp::verify() {
   SparseTensorType dstStt = getSparseTensorType(getResultCoo());
 
   if (!srcStt.isCOOType() || !dstStt.isCOOType())
-    emitError("Expected COO sparse tensors only");
+    return emitError("Expected COO sparse tensors only");
 
   if (!srcStt.hasSameDimToLvl(dstStt))
-    emitError("Unmatched dim2lvl map between input and result COO");
+    return emitError("Unmatched dim2lvl map between input and result COO");
 
   if (srcStt.getPosType() != dstStt.getPosType() ||
       srcStt.getCrdType() != dstStt.getCrdType() ||
       srcStt.getElementType() != dstStt.getElementType())
-    emitError("Unmatched storage format between input and result COO");
+    return emitError("Unmatched storage format between input and result COO");
 
   return success();
 }
@@ -2044,10 +2043,11 @@ LogicalResult SortOp::verify() {
   AffineMap xPerm = getPermMap();
   uint64_t nx = xPerm.getNumDims();
   if (nx < 1)
-    emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
+    return emitError(llvm::formatv("Expected rank(perm_map) > 1, got {0}", nx));
 
   if (!xPerm.isPermutation())
-    emitError(llvm::formatv("Expected a permutation map, got {0}", xPerm));
+    return emitError(
+        llvm::formatv("Expected a permutation map, got {0}", xPerm));
 
   // We can't check the size of the buffers when n or buffer dimensions aren't
   // compile-time constants.
@@ -2056,19 +2056,24 @@
     return success();
 
   // Verify dimensions.
-  const auto checkDim = [&](Value v, Size minSize, const char *message) {
+  const auto checkDim = [&](Value v, Size minSize,
+                            const char *message) -> LogicalResult {
     const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
-      emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+      return emitError(
+          llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
+    return success();
   };
   uint64_t n = cn.value();
   uint64_t ny = 0;
   if (auto nyAttr = getNyAttr())
     ny = nyAttr.getInt();
-  checkDim(getXy(), n * (nx + ny),
-           "Expected dimension(xy) >= n * (rank(perm_map) + ny)");
+  if (failed(checkDim(getXy(), n * (nx + ny),
+                      "Expected dimension(xy) >= n * (rank(perm_map) + ny)")))
+    return failure();
   for (Value opnd : getYs())
-    checkDim(opnd, n, "Expected dimension(y) >= n");
+    if (failed(checkDim(opnd, n, "Expected dimension(y) >= n")))
+      return failure();
 
   return success();
 }
@@ -2101,8 +2106,8 @@ static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
   }
 
   if (lvlHi <= lvlLo)
-    parser.emitError(parser.getNameLoc(),
-                     "expect larger level upper bound than lower bound");
+    return parser.emitError(parser.getNameLoc(),
+                            "expect larger level upper bound than lower bound");
 
   return success();
 }
mlir/test/Dialect/SparseTensor/invalid.mlir (3 additions, 3 deletions)
@@ -105,7 +105,7 @@ func.func @invalid_positions_dense(%arg0: tensor<128xf64>) -> memref<?xindex> {

 func.func @invalid_positions_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.positions' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.positions %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.positions"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }

@@ -141,7 +141,7 @@ func.func @invalid_indices_dense(%arg0: tensor<10x10xi32>) -> memref<?xindex> {

 func.func @invalid_indices_unranked(%arg0: tensor<*xf64>) -> memref<?xindex> {
   // expected-error@+1 {{'sparse_tensor.coordinates' op operand #0 must be sparse tensor of any type values, but got 'tensor<*xf64>'}}
-  %0 = sparse_tensor.coordinates %arg0 { level = 0 : index } : tensor<*xf64> to memref<?xindex>
+  %0 = "sparse_tensor.coordinates"(%arg0) { level = 0 : index } : (tensor<*xf64>) -> (memref<?xindex>)
   return %0 : memref<?xindex>
 }

@@ -347,7 +347,7 @@ func.func @sparse_wrong_arity_compression(%arg0: memref<?xf64>,
 // -----
 
 func.func @sparse_convert_unranked(%arg0: tensor<*xf32>) -> tensor<10xf32> {
-  // expected-error@+1 {{unexpected type in convert}}
+  // expected-error@+1 {{invalid kind of type specified}}
   %0 = sparse_tensor.convert %arg0 : tensor<*xf32> to tensor<10xf32>
   return %0 : tensor<10xf32>
 }