@@ -78,12 +78,8 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
78
78
// / Helper method to inspect affine expressions. Rejects cases where the
79
79
// / same index is used more than once. Also rejects compound affine
80
80
// / expressions in sparse dimensions.
81
- // / filterIdx stores the current filter loop idx should be used for the next
82
- // / compound affine sparse level, and it will be incremented by one when
83
- // / used.
84
81
static bool findAffine (Merger &merger, TensorId tid, Level lvl, AffineExpr a,
85
- DimLevelType dlt, LoopId &filterLdx,
86
- bool setLvlFormat = true ) {
82
+ DimLevelType dlt, bool setLvlFormat = true ) {
87
83
switch (a.getKind ()) {
88
84
case AffineExprKind::DimId: {
89
85
const LoopId idx = merger.makeLoopId (cast<AffineDimExpr>(a).getPosition ());
@@ -97,22 +93,14 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
97
93
case AffineExprKind::Add:
98
94
case AffineExprKind::Mul:
99
95
case AffineExprKind::Constant: {
100
- if (!isDenseDLT (dlt) && setLvlFormat) {
101
- assert (isUndefDLT (merger.getLvlType (tid, filterLdx)));
102
- // Use a filter loop for sparse affine expression.
103
- merger.setLevelAndType (tid, filterLdx, lvl, dlt);
104
- ++filterLdx;
105
- }
106
-
96
+ assert (isDenseDLT (dlt));
107
97
if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
108
98
// We do not set dim level format for affine expression like d0 + d1 on
109
99
// either loop index at d0 or d1.
110
100
// We continue the recursion merely to check whether current affine is
111
101
// admissible or not.
112
- return findAffine (merger, tid, lvl, binOp.getLHS (), dlt, filterLdx,
113
- false ) &&
114
- findAffine (merger, tid, lvl, binOp.getRHS (), dlt, filterLdx,
115
- false );
102
+ return findAffine (merger, tid, lvl, binOp.getLHS (), dlt, false ) &&
103
+ findAffine (merger, tid, lvl, binOp.getRHS (), dlt, false );
116
104
}
117
105
// Falls through when it is a constant Affine
118
106
return true ;
@@ -225,32 +213,13 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
225
213
return 0 ;
226
214
const SparseTensorType stt (rtp);
227
215
228
- // FIXME: There's some dim/lvl confusion here. The previous version of
229
- // the code asserted that there are `lvlRank`-many expressions, but then
230
- // the `exprs[d]` expression assumes there are in fact `dimRank`-many
231
- // expressions. Even though `ArrayRef::operator[]` will check for OOB,
232
- // the mismatch between the assertion and the usage belies that this code
233
- // cannot support non-permutations.
234
- //
235
- // Elsewhere in this file the maps returned by
236
- // `linalg::GenericOp::getMatchingIndexingMap` are inconsistent about
237
- // whether they're expected to have `lvlRank`-many or `dimRank`-many
238
- // expressions (cf., `genSubscript` vs `findSparseAnnotations`);
239
- // so those are no help in determining which is actually intended.
240
- //
241
- // For now we work around this problem by asserting the two ranks agree.
242
- const Dimension dimRank = stt.getDimRank ();
243
216
const Level lvlRank = stt.getLvlRank ();
244
- assert (dimRank == lvlRank && " Non-permutations not currently supported" );
245
217
const auto exprs = map.getResults ();
246
- assert (static_cast <Dimension>(exprs.size ()) == dimRank &&
218
+ assert (static_cast <Dimension>(exprs.size ()) == lvlRank &&
247
219
" AffineMap does not have dimension-rank many results" );
248
- (void )dimRank;
249
220
unsigned num = 0 ;
250
221
for (Level l = 0 ; l < lvlRank; l++) {
251
- // FIXME: `toOrigDim` is deprecated.
252
- const Dimension d = toOrigDim (stt.getEncoding (), l);
253
- if (!isa<AffineDimExpr>(exprs[d]) && !stt.isDenseLvl (l))
222
+ if (!isa<AffineDimExpr>(exprs[l]) && !stt.isDenseLvl (l))
254
223
num++;
255
224
}
256
225
return num;
@@ -281,15 +250,10 @@ static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
281
250
// / no annotations are found or inadmissible constructs occur.
282
251
// / We currently support two different ways to handle non-trivial index
283
252
// / expression on sparse tensors, and they accept different affine expressions.
284
- // / When using filter-loop-based approach, it accept (almost) arbitrary affine
285
- // / index expression on sparse tensor but it is much less efficient, and will be
286
- // / gradually removed from the codebase.
287
253
// / When using dependent index reduction-based approach, it currently only
288
254
// / supports affine addition index expression.
289
255
static bool findSparseAnnotations (CodegenEnv &env, bool idxReducBased) {
290
256
bool annotated = false ;
291
- // `filterLdx` may be mutated by `findAffine`.
292
- LoopId filterLdx = env.merger ().getStartingFilterLoopId ();
293
257
for (OpOperand &t : env.op ()->getOpOperands ()) {
294
258
const TensorId tid = env.makeTensorId (t.getOperandNumber ());
295
259
const auto map = env.op ().getMatchingIndexingMap (&t);
@@ -310,19 +274,17 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
310
274
// If the current tensor being inspected requires an affine index, it needs
311
275
// to be sliced.
312
276
for (Level l = 0 ; l < lvlRank; l++) {
313
- // FIXME: `toOrigDim` is deprecated.
314
- const AffineExpr a = map.getResult (toOrigDim (enc, l));
277
+ const AffineExpr a = map.getResult (l);
315
278
const DimLevelType dlt = enc.getLvlType (l);
316
279
if (idxReducBased && needIdxReduc) {
317
280
if (!findDepIdxSet (env.merger (), tid, l, a, dlt))
318
281
return false ; // inadmissible affine expression
319
282
} else {
320
- if (!findAffine (env.merger (), tid, l, a, dlt, filterLdx ))
283
+ if (!findAffine (env.merger (), tid, l, a, dlt))
321
284
return false ; // inadmissible affine expression
322
285
}
323
286
}
324
287
}
325
- assert (filterLdx == env.merger ().getNumLoops ());
326
288
return annotated;
327
289
}
328
290
@@ -374,13 +336,8 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
374
336
}
375
337
return init;
376
338
},
377
- [&loopRange, &env](OpBuilder &b, Location loc, Level l) {
378
- assert (l < env.getLoopNum ());
379
- // FIXME: Remove filter loop since we have a better algorithm to
380
- // deal with affine index expression.
381
- if (l >= env.merger ().getStartingFilterLoopId ())
382
- return Value ();
383
-
339
+ [&loopRange](OpBuilder &b, Location loc, Level l) {
340
+ assert (l < loopRange.size ());
384
341
return mlir::getValueOrCreateConstantIndexOp (b, loc, loopRange[l].size );
385
342
});
386
343
}
@@ -394,10 +351,7 @@ static Value genIndex(CodegenEnv &env, OpOperand *t) {
394
351
const auto stt = getSparseTensorType (t->get ());
395
352
const Level lvlRank = stt.getLvlRank ();
396
353
assert (static_cast <Level>(map.getNumResults ()) == lvlRank);
397
- // FIXME: `toOrigDim` is deprecated.
398
- // FIXME: above we asserted that there are `lvlRank` many results,
399
- // but this is assuming there are in fact `dimRank` many results instead.
400
- const AffineExpr a = map.getResult (toOrigDim (stt.getEncoding (), lvlRank - 1 ));
354
+ const AffineExpr a = map.getResult (lvlRank - 1 );
401
355
assert (a.getKind () == AffineExprKind::DimId);
402
356
const LoopId idx = env.makeLoopId (cast<AffineDimExpr>(a).getPosition ());
403
357
return env.getLoopVar (idx);
@@ -727,19 +681,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
727
681
const Level lvlRank = stt.getLvlRank ();
728
682
assert (static_cast <Level>(map.getNumResults ()) == lvlRank);
729
683
for (Level l = 0 ; l < lvlRank; l++) {
730
- // FIXME: `toOrigDim` is deprecated.
731
- // FIXME: above we asserted that there are `lvlRank` many results,
732
- // but this is assuming there are in fact `dimRank` many results instead.
733
- const AffineExpr a = map.getResult (toOrigDim (stt.getEncoding (), l));
734
- const auto sldx =
735
- env.merger ().getLoopId (env.makeTensorId (t.getOperandNumber ()), l);
736
- if (sldx && env.merger ().isFilterLoop (*sldx)) {
737
- if (!env.getLoopVar (*sldx))
738
- // The filter loops has not been constructed.
739
- return ;
740
- if (*sldx == ldx)
741
- isAtLoop = true ;
742
- } else if (!isInvariantAffine (a, env.getLoopDepth (), ldx, isAtLoop))
684
+ const AffineExpr a = map.getResult (l);
685
+ if (!isInvariantAffine (a, env.getLoopDepth (), ldx, isAtLoop))
743
686
return ; // still in play
744
687
}
745
688
// All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -1073,10 +1016,8 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
1073
1016
const TensorId tid = env.makeTensorId (input->getOperandNumber ());
1074
1017
const Level lvlRank = enc.getLvlRank ();
1075
1018
assert (lvlExprs.size () == static_cast <size_t >(lvlRank));
1076
- // FIXME: there is dim/lvl confusion here
1077
1019
for (Level l = startLvl; l < lvlRank; l++) {
1078
- // FIXME: `toOrigDim` is deprecated.
1079
- AffineExpr lvlExpr = lvlExprs[toOrigDim (enc, l)];
1020
+ AffineExpr lvlExpr = lvlExprs[l];
1080
1021
if (enc.isDenseLvl (l) && isa<AffineConstantExpr>(lvlExpr))
1081
1022
env.emitter ().genDenseAffineAddress (
1082
1023
builder, loc, env.makeTensorLevel (tid, l), lvlExpr);
@@ -1164,8 +1105,7 @@ static bool translateBitsToTidLvlPairs(
1164
1105
const Level lvlRank = stt.getLvlRank ();
1165
1106
assert (affines.size () == static_cast <size_t >(lvlRank));
1166
1107
for (Level l = 0 ; l < lvlRank; l++) {
1167
- // FIXME: `toOrigDim` is deprecated.
1168
- AffineExpr exp = affines[toOrigDim (stt.getEncoding (), l)];
1108
+ AffineExpr exp = affines[l];
1169
1109
// Skip simple affine expression and non-dense levels (which
1170
1110
// have their own filter loop).
1171
1111
if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl (l))
@@ -1396,14 +1336,13 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1396
1336
op, " Loops not yet scheduled, try run --sparse-reinterpret-map "
1397
1337
" before sparsification." );
1398
1338
}
1339
+ // Must have been demapped as well if the generic op is sorted.
1340
+ assert (!hasAnyNonIdentityOperandsOrResults (op));
1399
1341
1400
1342
// Sets up a code generation environment.
1401
1343
const unsigned numTensors = op->getNumOperands ();
1402
1344
const unsigned numLoops = op.getNumLoops ();
1403
- const unsigned numFilterLoops = getNumNonTrivialIdxExpOnSparseLvls (op);
1404
- // TODO: we should probably always use slice-based codegen whenever
1405
- // possible, we can even intermix slice-based and filter-loop based codegen.
1406
- bool idxReducBased = numFilterLoops != 0 ;
1345
+ bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls (op) != 0 ;
1407
1346
// If we have indexing map like (d0) -> (0, d0), there might be more
1408
1347
// levels than loops because of the constant index, which means we cannot
1409
1348
// use numLoops as the upper bound for ranks of all tensors.
@@ -1417,14 +1356,10 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
1417
1356
}
1418
1357
}
1419
1358
1420
- // A slice based algorithm for affine indices does not need filter loops.
1421
- CodegenEnv env (op, options, numTensors, numLoops,
1422
- /* numFilterLoops=*/ idxReducBased ? 0 : numFilterLoops,
1423
- maxLvlRank);
1424
-
1359
+ CodegenEnv env (op, options, numTensors, numLoops, maxLvlRank);
1425
1360
// Detects sparse annotations and translates the per-level sparsity
1426
1361
// information for all tensors to loop indices in the kernel.
1427
- if (!findSparseAnnotations (env, idxReducBased ))
1362
+ if (!findSparseAnnotations (env, needIdxRed ))
1428
1363
return failure ();
1429
1364
1430
1365
// Only standard reduction operations (add, sub, or, xor) that can be
0 commit comments