@@ -1939,9 +1939,12 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1939
1939
1940
1940
LogicalResult matchAndRewrite (linalg::GenericOp op,
1941
1941
PatternRewriter &rewriter) const override {
1942
- // Only accept single output operations without affine index on sparse
1943
- // output.
1944
- if (op.getNumDpsInits () != 1 || hasNonTrivialAffineOnSparseOut (op))
1942
+ // Only accept single output operations with pure tensor semantics.
1943
+ if (op.getNumDpsInits () != 1 || !op.hasTensorSemantics ())
1944
+ return failure ();
1945
+
1946
+ // Only accept trivial affine indices.
1947
+ if (hasNonTrivialAffineOnSparseOut (op))
1945
1948
return failure ();
1946
1949
1947
1950
// Sets up a code generation environment.
@@ -1951,7 +1954,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1951
1954
// TODO: we should probably always use slice-based codegen whenever
1952
1955
// possible, we can even intermix slice-based and filter-loop based codegen.
1953
1956
bool idxReducBased = options.enableIndexReduction && numFilterLoops != 0 ;
1954
-
1955
1957
// If we have indexing map like (d0) -> (0, d0), there might be more
1956
1958
// levels than loops because of the constant index, that means we cannot
1957
1959
// use numLoops as the upper bound for ranks of all tensors.
@@ -1964,9 +1966,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1964
1966
maxLvlRank = std::max (maxLvlRank, SparseTensorType (rtp).getLvlRank ());
1965
1967
}
1966
1968
}
1967
-
1968
- // If we uses slice based algorithm for affine index, we do not need filter
1969
- // loop.
1969
+ // A slice based algorithm for affine indices does not need filter loops.
1970
1970
CodegenEnv env (op, options, numTensors, numLoops,
1971
1971
/* numFilterLoops=*/ idxReducBased ? 0 : numFilterLoops,
1972
1972
maxLvlRank);
@@ -2006,7 +2006,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
2006
2006
// to resolve cycles by inserting a conversion.
2007
2007
bool isAdmissible = false ;
2008
2008
bool hasCycle = true ;
2009
-
2010
2009
// A const list of all masks that we used for iteration graph
2011
2010
// computation. Must be ordered from more strict to less strict.
2012
2011
// Ideally (though might not be guaranteed), the earlier a constraint mask
@@ -2030,7 +2029,6 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
2030
2029
? failure () // TODO: should cycle be resolved differently?
2031
2030
: resolveCycle (env, rewriter); // one last shot
2032
2031
}
2033
-
2034
2032
if (!isAdmissible)
2035
2033
return failure (); // inadmissible expression, reject
2036
2034
0 commit comments