[mlir][sparse] test for linalg tensor semantics #70254

Merged: 1 commit, Oct 25, 2023
16 changes: 7 additions & 9 deletions in mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1939,9 +1939,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
-    // Only accept single output operations without affine index on sparse
-    // output.
-    if (op.getNumDpsInits() != 1 || hasNonTrivialAffineOnSparseOut(op))
+    // Only accept single output operations with pure tensor semantics.
+    if (op.getNumDpsInits() != 1 || !op.hasTensorSemantics())
+      return failure();
+
+    // Only accept trivial affine indices.
+    if (hasNonTrivialAffineOnSparseOut(op))
       return failure();
 
     // Sets up a code generation environment.
@@ -1951,7 +1954,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // TODO: we should probably always use slice-based codegen whenever
     // possible, we can even intermix slice-based and filter-loop based codegen.
     bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0;
-
     // If we have indexing map like (d0) -> (0, d0), there might be more
     // levels then loops because of the constant index, that means we can not
     // use numLoops as the upper bound for ranks of all tensors.
@@ -1964,9 +1966,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
         maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
       }
     }
-
-    // If we uses slice based algorithm for affine index, we do not need filter
-    // loop.
+    // A slice based algorithm for affine indices does not need filter loops.
     CodegenEnv env(op, options, numTensors, numLoops,
                    /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops,
                    maxLvlRank);
@@ -2006,7 +2006,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     // to resolve cycles by inserting a conversion.
     bool isAdmissible = false;
     bool hasCycle = true;
-
    // A const list of all masks that we used for iteration graph
    // computation. Must be ordered from more strict to less strict.
    // Ideally (though might not be guaranteed), the earlier a constraint mask
@@ -2030,7 +2029,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
                 ? failure() // TODO: should cycle be resolved differently?
                 : resolveCycle(env, rewriter); // one last shot
     }
-
     if (!isAdmissible)
       return failure(); // inadmissible expression, reject
 
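The new guard relies on linalg's distinction between tensor and buffer semantics. As an illustrative sketch only (not the test added by this PR; the function names, shapes, and computation are invented), the same elementwise op is shown below in both forms. The first form passes the new check, while the second now fails fast through the !op.hasTensorSemantics() test instead of reaching later codegen stages:

// Pure tensor semantics: tensor operands, tensor result.
#id = affine_map<(d0) -> (d0)>

func.func @neg_tensors(%arg0: tensor<8xf32>, %init: tensor<8xf32>) -> tensor<8xf32> {
  %0 = linalg.generic {indexing_maps = [#id, #id],
                       iterator_types = ["parallel"]}
      ins(%arg0 : tensor<8xf32>) outs(%init : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %1 = arith.negf %in : f32
    linalg.yield %1 : f32
  } -> tensor<8xf32>
  return %0 : tensor<8xf32>
}

// Buffer semantics: the same computation on memrefs yields no tensor
// results, so the sparsifier pattern now returns failure() up front.
func.func @neg_buffers(%arg0: memref<8xf32>, %out: memref<8xf32>) {
  linalg.generic {indexing_maps = [#id, #id],
                  iterator_types = ["parallel"]}
      ins(%arg0 : memref<8xf32>) outs(%out : memref<8xf32>) {
  ^bb0(%in: f32, %o: f32):
    %1 = arith.negf %in : f32
    linalg.yield %1 : f32
  }
  return
}

A real sparsification test would also attach a #sparse_tensor.encoding to at least one operand; that is omitted here because hasTensorSemantics() does not depend on the encodings.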
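The retained comment about indexing maps such as (d0) -> (0, d0) refers to constant indices, which let an operand have more levels than the op has loops; that is why the code computes maxLvlRank instead of using numLoops as an upper bound on tensor ranks. A hedged sketch of such a case (again with invented names and shapes, not taken from this PR): the single-loop op below reads row 0 of a rank-2 input.

// One loop (d0), but the input is indexed with a constant in dimension 0,
// so its rank (2) exceeds the loop count (1).
#row0 = affine_map<(d0) -> (0, d0)>
#id1  = affine_map<(d0) -> (d0)>

func.func @read_row0(%in: tensor<1x8xf32>, %init: tensor<8xf32>) -> tensor<8xf32> {
  %0 = linalg.generic {indexing_maps = [#row0, #id1],
                       iterator_types = ["parallel"]}
      ins(%in : tensor<1x8xf32>) outs(%init : tensor<8xf32>) {
  ^bb0(%x: f32, %y: f32):
    linalg.yield %x : f32
  } -> tensor<8xf32>
  return %0 : tensor<8xf32>
}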