@@ -232,7 +232,10 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->hasOutput = hasOutput;
   this->isSparseOut = isSparseOut;

-  const unsigned numTensors = ts.size();
+  const unsigned numManifestTensors = ts.size();
+  const unsigned synTensorId = numManifestTensors;
+  const unsigned numTensors = numManifestTensors + 1;
+
   this->tensors.assign(ts.begin(), ts.end());
   this->lvlTypes.assign(numTensors, std::vector<DimLevelType>());
   this->lvlSizes.assign(numTensors, std::vector<Value>());
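The hunk above appends one extra `TensorId` after all manifest (operand) tensors: ids 0 through numManifestTensors - 1 name real operands, while id numManifestTensors is reserved for the synthetic tensor, so every `TensorId`-indexed field is sized with one additional slot. A minimal standalone sketch of this id layout (hypothetical names, not part of the patch):

    // Sketch only: the TensorId layout implied by the hunk above.
    //   tid in [0, numManifestTensors)  -> manifest tensors (real operands)
    //   tid == numManifestTensors       -> synthetic tensor (conceptual only)
    #include <cassert>

    int main() {
      const unsigned numManifestTensors = 3; // e.g., two inputs + one output
      const unsigned synTensorId = numManifestTensors;
      const unsigned numTensors = numManifestTensors + 1;
      // Every tid-indexed vector (lvlTypes, lvlSizes, ...) gets numTensors
      // slots, so the synthetic tensor is indexed like any manifest tensor.
      assert(synTensorId < numTensors);
      return 0;
    }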
@@ -265,33 +268,43 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,

   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
-    const Value t = tensors[tid];
-    // A scalar or 0-dimension tensor.
-    if (isZeroRankedTensorOrScalar(t.getType()))
-      continue;
-
-    auto rtp = getRankedTensorType(t);
-    if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
-        isUniqueCOOType(rtp) && reshape) {
-      // TODO: Support more kinds of sparse tensors.
-      // FIXME: We should instead lower reshape operations on sparse tensors to
-      // view change.
-      collapseReassoc[tid] = reshape.getReassociation();
-      rtp = reshape.getSrcType();
-      // Overwrite the tensor with the source tensor of the reshape operation.
-      tensors[tid] = reshape.getSrc();
-    }
-    const SparseTensorType stt(rtp);
-    const Level lvlRank = stt.getLvlRank();
-    // We always treat a sparse output tensor as dense so that we always
-    // iterate it based on lvl size.
-    if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
-      const auto enc = stt.getEncoding();
-      isSparseSlices[tid] = enc.isSlice();
-      for (auto lvlTp : enc.getLvlTypes())
-        lvlTypes[tid].push_back(lvlTp);
-    } else {
+    Level lvlRank;
+    if (tid == synTensorId) {
+      // The synthetic tensor is (conceptually) an all-dense tensor with rank
+      // equal to the total number of loops (each level can potentially be
+      // mapped to one of the loops being generated).
+      lvlRank = numLoops;
       lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+    } else {
+      const Value t = tensors[tid];
+      // A scalar or 0-dimension tensor.
+      if (isZeroRankedTensorOrScalar(t.getType()))
+        continue;
+
+      auto rtp = getRankedTensorType(t);
+      if (auto reshape = t.getDefiningOp<tensor::CollapseShapeOp>();
+          isUniqueCOOType(rtp) && reshape) {
+        // TODO: Support more kinds of sparse tensors.
+        // FIXME: We should instead lower reshape operations on sparse tensors
+        // to view change.
+        collapseReassoc[tid] = reshape.getReassociation();
+        rtp = reshape.getSrcType();
+        // Overwrite the tensor with the source tensor of the reshape operation.
+        tensors[tid] = reshape.getSrc();
+      }
+      const SparseTensorType stt(rtp);
+      lvlRank = stt.getLvlRank();
+
+      // We always treat a sparse output tensor as dense so that we always
+      // iterate it based on lvl size.
+      if (stt.hasEncoding() && !(isOutputTensor(tid) && isSparseOut)) {
+        const auto enc = stt.getEncoding();
+        isSparseSlices[tid] = enc.isSlice();
+        for (auto lvlTp : enc.getLvlTypes())
+          lvlTypes[tid].push_back(lvlTp);
+      } else {
+        lvlTypes[tid].assign(lvlRank, DimLevelType::Dense);
+      }
     }

     // Initialize using empty value.
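Because the synthetic tensor carries no encoding, the loop above simply marks all of its numLoops levels as dense, so that loops not driven by any manifest sparse level still have a well-defined iteration space. A rough illustration of the resulting `lvlTypes` state (a stub enum and hypothetical values, not the dialect's real `DimLevelType`):

    // Sketch only: lvlTypes after initialize() for two manifest tensors plus
    // the synthetic tensor, assuming numLoops == 3.
    #include <vector>

    enum class DimLevelType { Dense, Compressed }; // stub of the real enum

    int main() {
      const unsigned numLoops = 3;
      std::vector<std::vector<DimLevelType>> lvlTypes(3); // 2 manifest + 1 synthetic
      // tid 0: a CSR-like operand -> level types taken from its encoding.
      lvlTypes[0] = {DimLevelType::Dense, DimLevelType::Compressed};
      // tid 1: a dense operand (no encoding) -> all levels dense.
      lvlTypes[1].assign(2, DimLevelType::Dense);
      // tid 2: the synthetic tensor -> one dense level per loop.
      lvlTypes[2].assign(numLoops, DimLevelType::Dense);
      return 0;
    }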
@@ -314,7 +327,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
                                  std::nullopt, 0);
-    if (dimGetter) {
+    if (dimGetter && !isSynTensor(tid)) {
       auto reassoc = collapseReassoc[tid];
       Level dstRank = reassoc ? reassoc.size() : lvlRank;
       for (Level l = 0; l < dstRank; l++) {
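The hunk above now skips the synthetic tensor when querying `dimGetter`, presumably because a purely conceptual tensor has no user-written dimension expressions to look up. The surrounding context also shows the `dstRank` computation: a tensor produced by `tensor.collapse_shape` uses the size of its reassociation rather than its level rank, e.g. a 3-D source collapsed via [[0, 1], [2]] has destination rank 2. A worked sketch of that rank selection (plain containers standing in for the MLIR attributes):

    // Sketch only: dstRank selection as in the hunk above, with std::vector
    // standing in for the reassociation attribute.
    #include <optional>
    #include <vector>

    int main() {
      using Reassociation = std::vector<std::vector<unsigned>>;
      const unsigned lvlRank = 3;
      std::optional<Reassociation> reassoc = Reassociation{{0, 1}, {2}};
      // With a collapse: rank of the reassociation (2); otherwise: lvlRank.
      unsigned dstRank =
          reassoc ? static_cast<unsigned>(reassoc->size()) : lvlRank; // 2
      (void)dstRank;
      return 0;
    }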
@@ -461,15 +474,28 @@ void LoopEmitter::enterNewLoopSeq(OpBuilder &builder, Location loc,
   assert(loopSeqStack.size() == loopStack.size());
   // Prepares for all the tensors used in the current loop sequence.
   std::vector<std::tuple<TensorId, Level, bool>> slicedTids;
+
+  bool hasSynTensor = false;
+  std::optional<std::pair<TensorId, Level>> loopBoundDefLevel = std::nullopt;
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
     if (!dependentLvlMap[tid][lvl].empty()) {
       bool fullyRed = genSliceBegin(builder, loc, tid, lvl);
       slicedTids.emplace_back(tid, lvl, fullyRed);
     } else {
-      prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      if (isSynTensor(tid)) {
+        hasSynTensor = true;
+      } else {
+        loopBoundDefLevel = std::make_pair(tid, lvl);
+        prepareLoopOverTensorAtLvl(builder, loc, tid, lvl);
+      }
     }
   }

+  if (hasSynTensor && loopBoundDefLevel.has_value()) {
+    // TODO: compute the loopBound for index reduction by d - sum(unres_lvls).
+    highs[getSynTensorId()][getCurrentDepth()] =
+        lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second];
+  }
   // Universal Index starts from 0.
   loopSeqStack.emplace_back(C_IDX(0), std::move(slicedTids));
 }
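In `enterNewLoopSeq`, the synthetic tensor never reaches `prepareLoopOverTensorAtLvl`; instead, once any manifest (tid, lvl) in the same loop sequence has been prepared, that level's size is recorded as the synthetic tensor's upper bound at the current depth (the TODO notes that a tighter bound for index reduction is still pending). A simplified model of that borrowing, with ints standing in for `mlir::Value`:

    // Sketch only: the synthetic tensor's loop bound is copied from the level
    // size of a manifest (tid, lvl) seen in the same loop sequence.
    #include <optional>
    #include <utility>
    #include <vector>

    int main() {
      const unsigned synTensorId = 2, curDepth = 0;
      // lvlSizes[tid][lvl]; hypothetical sizes for two manifest tensors.
      std::vector<std::vector<int>> lvlSizes = {{10, 20}, {8}, {}};
      std::vector<std::vector<int>> highs(3, std::vector<int>(1, 0));
      std::optional<std::pair<unsigned, unsigned>> loopBoundDefLevel =
          std::make_pair(0u, 1u); // a prepared manifest level
      if (loopBoundDefLevel.has_value())
        highs[synTensorId][curDepth] =
            lvlSizes[loopBoundDefLevel->first][loopBoundDefLevel->second]; // 20
      return 0;
    }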
@@ -1137,6 +1163,9 @@ void LoopEmitter::emitExtraLocalsForTensorsAtDenseLvls(
   // output tensor unconditionally, since they may not appear in the lattice,
   // but may be needed for linearized codegen.
   for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
+    if (isSynTensor(tid))
+      continue;
+
     if (isDenseDLT(lvlTypes[tid][lvl])) {
       // Slice-driven dense levels should have been handled already.
       if (!dependentLvlMap[tid][lvl].empty())