Skip to content

Commit 740582f

Browse files
authored
[mlir][sparse] test for linalg tensor semantics (#70254)
This test used to be here, but somehow got lost while linalg rewrote their interfaces. It is essential to test this on entry of sparsification, however, since all subsequent analysis simply assumes tensor types. Fixes: #64325
1 parent 3dbcd73 commit 740582f

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,9 +1939,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
19391939

19401940
LogicalResult matchAndRewrite(linalg::GenericOp op,
19411941
PatternRewriter &rewriter) const override {
1942-
// Only accept single output operations without affine index on sparse
1943-
// output.
1944-
if (op.getNumDpsInits() != 1 || hasNonTrivialAffineOnSparseOut(op))
1942+
// Only accept single output operations with pure tensor semantics.
1943+
if (op.getNumDpsInits() != 1 || !op.hasTensorSemantics())
1944+
return failure();
1945+
1946+
// Only accept trivial affine indices.
1947+
if (hasNonTrivialAffineOnSparseOut(op))
19451948
return failure();
19461949

19471950
// Sets up a code generation environment.
@@ -1951,7 +1954,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
19511954
// TODO: we should probably always use slice-based codegen whenever
19521955
// possible, we can even intermix slice-based and filter-loop based codegen.
19531956
bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0;
1954-
19551957
// If we have indexing map like (d0) -> (0, d0), there might be more
19561958
// levels then loops because of the constant index, that means we can not
19571959
// use numLoops as the upper bound for ranks of all tensors.
@@ -1964,9 +1966,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
19641966
maxLvlRank = std::max(maxLvlRank, SparseTensorType(rtp).getLvlRank());
19651967
}
19661968
}
1967-
1968-
// If we uses slice based algorithm for affine index, we do not need filter
1969-
// loop.
1969+
// A slice based algorithm for affine indices does not need filter loops.
19701970
CodegenEnv env(op, options, numTensors, numLoops,
19711971
/*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops,
19721972
maxLvlRank);
@@ -2006,7 +2006,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
20062006
// to resolve cycles by inserting a conversion.
20072007
bool isAdmissible = false;
20082008
bool hasCycle = true;
2009-
20102009
// A const list of all masks that we used for iteration graph
20112010
// computation. Must be ordered from more strict to less strict.
20122011
// Ideally (though might not be guaranteed), the earlier a constraint mask
@@ -2030,7 +2029,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
20302029
? failure() // TODO: should cycle be resolved differently?
20312030
: resolveCycle(env, rewriter); // one last shot
20322031
}
2033-
20342032
if (!isAdmissible)
20352033
return failure(); // inadmissible expression, reject
20362034

0 commit comments

Comments
 (0)