@@ -43,29 +43,25 @@ static Value genIndexLoad(OpBuilder &builder, Location loc, Value mem,
43
43
return load;
44
44
}
45
45
46
- // TODO: Support dynamic sized slice.
47
- static Value getSliceOffset (OpBuilder &builder, Location loc,
48
- SparseTensorEncodingAttr enc, unsigned lvl) {
49
- return constantIndex (builder, loc, *enc.getStaticLvlSliceOffset (lvl));
46
+ static Value genSliceOffset (OpBuilder &builder, Location loc, Value tensor,
47
+ unsigned lvl) {
48
+ auto enc = getSparseTensorEncoding (tensor.getType ());
49
+ // FIXME: `toOrigDim` is deprecated
50
+ return createOrFoldSliceOffsetOp (builder, loc, tensor, toOrigDim (enc, lvl));
50
51
}
51
52
52
- static Value getSliceSize (OpBuilder &builder, Location loc,
53
- SparseTensorEncodingAttr enc, unsigned lvl) {
54
- return constantIndex (builder, loc, *enc.getStaticLvlSliceSize (lvl));
55
- }
56
-
57
- static Value getSliceStride (OpBuilder &builder, Location loc,
58
- SparseTensorEncodingAttr enc, unsigned lvl) {
59
- return constantIndex (builder, loc, *enc.getStaticLvlSliceStride (lvl));
53
+ static Value genSliceStride (OpBuilder &builder, Location loc, Value tensor,
54
+ unsigned lvl) {
55
+ auto enc = getSparseTensorEncoding (tensor.getType ());
56
+ // FIXME: `toOrigDim` is deprecated
57
+ return createOrFoldSliceStrideOp (builder, loc, tensor, toOrigDim (enc, lvl));
60
58
}
61
59
62
60
// Converts a coordinate relative to the slice to the coordinate relative
63
61
// to the underlying tensor.
64
62
static Value toSliceCoord (OpBuilder &builder, Location loc, Value v,
65
- SparseTensorEncodingAttr enc, unsigned lvl) {
66
-
67
- Value stride = getSliceStride (builder, loc, enc, lvl);
68
- Value offset = getSliceOffset (builder, loc, enc, lvl);
63
+ Value offset, Value stride, Value tensor,
64
+ unsigned lvl) {
69
65
// iv = iv * stride + offset
70
66
v = builder.create <arith::MulIOp>(loc, v, stride);
71
67
v = builder.create <arith::AddIOp>(loc, v, offset);
@@ -75,40 +71,58 @@ static Value toSliceCoord(OpBuilder &builder, Location loc, Value v,
75
71
// Converts a coordinate relative to the underlying tensor to the coordinate
76
72
// relative to the slice, returns a extra reminder value
77
73
static std::pair<Value, Value> fromSliceCrd (OpBuilder &builder, Location loc,
78
- Value v ,
79
- SparseTensorEncodingAttr enc ,
74
+ Value iv, Value offset ,
75
+ Value stride, Value tensor ,
80
76
unsigned lvl) {
81
- Value stride = getSliceStride (builder, loc, enc, lvl);
82
- Value offset = getSliceOffset (builder, loc, enc, lvl);
83
77
// iv = (iv - offset) / stride
84
- v = builder.create <arith::SubIOp>(loc, v , offset);
85
- Value rem = builder.create <arith::RemUIOp>(loc, v , stride);
86
- v = builder.create <arith::DivUIOp>(loc, v , stride);
87
- return std::make_pair (v , rem);
78
+ iv = builder.create <arith::SubIOp>(loc, iv , offset);
79
+ Value rem = builder.create <arith::RemUIOp>(loc, iv , stride);
80
+ iv = builder.create <arith::DivUIOp>(loc, iv , stride);
81
+ return std::make_pair (iv , rem);
88
82
}
89
83
90
- static std::pair<Value, Value>
91
- genSliceLegitPredicate (OpBuilder &builder, Location loc, Value crd,
92
- SparseTensorEncodingAttr enc, unsigned lvl) {
93
- std::pair<Value, Value> trans = fromSliceCrd (builder, loc, crd, enc, lvl);
94
- // First, crd >= offset (TODO: seems unsigned >= 0 won't be folded, skip
95
- // the check if the offset is zero).
96
- auto geOffset =
97
- builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::uge, crd,
98
- getSliceOffset (builder, loc, enc, lvl));
84
+ std::pair<Value, Value>
85
+ LoopEmitter::genSliceLegitPredicate (OpBuilder &builder, Location loc, Value crd,
86
+ unsigned tid, unsigned lvl) {
87
+ assert (isSparseSlices[tid]);
88
+ Value slice = tensors[tid];
89
+ Value offset = sliceOffsets[tid][lvl];
90
+ Value stride = sliceStrides[tid][lvl];
91
+ auto enc = getSparseTensorEncoding (slice.getType ());
92
+
93
+ std::pair<Value, Value> transformedCrd =
94
+ fromSliceCrd (builder, loc, crd, offset, stride, slice, lvl);
95
+
96
+ SmallVector<Value, 3 > conds; // at most 3 conditions
97
+
98
+ // First, coord >= offset (skip the check if offset is known to be 0).
99
+ if (auto staticOffset = enc.getStaticLvlSliceOffset (lvl);
100
+ !(staticOffset.has_value () && *staticOffset == 0 )) {
101
+ auto geOffset = builder.create <arith::CmpIOp>(
102
+ loc, arith::CmpIPredicate::uge, crd, offset);
103
+ conds.push_back (geOffset);
104
+ }
105
+
99
106
// Second, coord_in_slice < length
100
- auto ltLength =
101
- builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::ult, trans.first ,
102
- getSliceSize (builder, loc, enc, lvl));
103
-
104
- // Third, rem == 0; confirmed that (a % 1) will be folded to 0
105
- auto fitStride =
106
- builder.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq, trans.second ,
107
- constantIndex (builder, loc, 0 ));
108
-
109
- auto pred = builder.create <arith::AndIOp>(loc, geOffset, ltLength);
110
- pred = builder.create <arith::AndIOp>(loc, pred, fitStride);
111
- return {trans.first , pred};
107
+ auto ltLength = builder.create <arith::CmpIOp>(
108
+ loc, arith::CmpIPredicate::ult, transformedCrd.first , lvlSizes[tid][lvl]);
109
+ conds.push_back (ltLength);
110
+
111
+ // Third, rem == 0 (skip the check if stride is known to be 1).
112
+ if (auto staticStride = enc.getStaticLvlSliceStride (lvl);
113
+ !(staticStride.has_value () && *staticStride == 1 )) {
114
+ auto fitStride = builder.create <arith::CmpIOp>(
115
+ loc, arith::CmpIPredicate::eq, transformedCrd.second ,
116
+ constantIndex (builder, loc, 0 ));
117
+ conds.push_back (fitStride);
118
+ }
119
+
120
+ // Must meet all condition to be a valid coordinate in slice.
121
+ auto pred = conds.front ();
122
+ for (auto cond : ValueRange (conds).drop_front ())
123
+ pred = builder.create <arith::AndIOp>(loc, pred, cond);
124
+
125
+ return {transformedCrd.first , pred};
112
126
}
113
127
114
128
// ===----------------------------------------------------------------------===//
@@ -119,10 +133,9 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, size_t tid,
119
133
size_t dim, Value iv) {
120
134
Value p = dim == 0 ? constantIndex (builder, loc, 0 ) : pidxs[tid][dim - 1 ];
121
135
Value mul = builder.create <arith::MulIOp>(loc, highs[tid][dim], p);
122
- if (isSparseSlices[tid]) {
123
- auto enc = getSparseTensorEncoding (tensors[tid].getType ());
124
- iv = toSliceCoord (builder, loc, iv, enc, dim);
125
- }
136
+ if (isSparseSlices[tid])
137
+ iv = toSliceCoord (builder, loc, iv, sliceOffsets[tid][dim],
138
+ sliceStrides[tid][dim], tensors[tid], dim);
126
139
Value add = builder.create <arith::AddIOp>(loc, mul, iv);
127
140
return add;
128
141
}
@@ -204,6 +217,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
204
217
this ->isSparseOut = isSparseOut;
205
218
this ->tensors .assign (ts.begin (), ts.end ());
206
219
this ->isSparseSlices .assign (tensors.size (), false );
220
+ this ->sliceOffsets .assign (tensors.size (), std::vector<Value>());
221
+ this ->sliceStrides .assign (tensors.size (), std::vector<Value>());
207
222
this ->dimTypes .assign (tensors.size (), std::vector<DimLevelType>());
208
223
this ->pidxs .assign (tensors.size (), std::vector<Value>());
209
224
this ->segHi .assign (tensors.size (), std::vector<Value>());
@@ -246,6 +261,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
246
261
dimTypes[tid].assign (rank, DimLevelType::Dense);
247
262
248
263
// Initialize using empty value.
264
+ sliceOffsets[tid].assign (rank, Value ());
265
+ sliceStrides[tid].assign (rank, Value ());
249
266
pidxs[tid].assign (rank, Value ());
250
267
segHi[tid].assign (rank, Value ());
251
268
coord[tid].assign (rank, Value ());
@@ -300,11 +317,17 @@ void LoopEmitter::initializeLoopEmit(OpBuilder &builder, Location loc,
300
317
assert (isDenseDLT (dlt));
301
318
}
302
319
303
- // Find upper bound in current dimension.
304
320
// FIXME: `toOrigDim` is deprecated
305
- const Dimension d = toOrigDim (enc, l);
306
- lvlSizes[t][l] = highs[t][l] =
307
- mlir::linalg::createOrFoldDimOp (builder, loc, tensor, d);
321
+ // Since we do not have HigherOrdering now, we can always rely on the 1:1
322
+ // mapping from level to dimension to retrieve the level size.
323
+ Value lvlSz = mlir::linalg::createOrFoldDimOp (builder, loc, tensor,
324
+ toOrigDim (enc, l));
325
+ // Find upper bound in current dimension.
326
+ highs[t][l] = lvlSizes[t][l] = lvlSz;
327
+ if (isSparseSlices[t]) {
328
+ sliceOffsets[t][l] = genSliceOffset (builder, loc, tensors[t], l);
329
+ sliceStrides[t][l] = genSliceStride (builder, loc, tensors[t], l);
330
+ }
308
331
}
309
332
310
333
// Perform the required bufferization. Dense inputs materialize
@@ -405,7 +428,6 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
405
428
isSparseInput = isSparseInput || isSparse;
406
429
}
407
430
408
- auto enc = getSparseTensorEncoding (tensors[tid].getType ());
409
431
const auto reassoc = getCollapseReassociation (tid, dim);
410
432
// TODO: support dynamic slices.
411
433
// Uses the first dimension here to build the loop bound (which is also the
@@ -468,7 +490,7 @@ Operation *LoopEmitter::enterLoopOverTensorAtDim(
468
490
for (Value red : reduc)
469
491
types.push_back (red.getType ());
470
492
471
- auto [trans, pred] = genSliceLegitPredicate (builder, loc, crd, enc , dim);
493
+ auto [trans, pred] = genSliceLegitPredicate (builder, loc, crd, tid , dim);
472
494
bool hasReduc = !types.empty ();
473
495
scf::IfOp ifOp = builder.create <scf::IfOp>(loc, types, pred,
474
496
/* else*/ hasReduc);
@@ -660,11 +682,8 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
660
682
isSingletonDLT (dimTypes[tid][dim])) {
661
683
coord[tid][dim] = genSparseCrd (builder, loc, tid, dim);
662
684
if (isSparseSlices[tid]) {
663
- Value load =
664
- genIndexLoad (builder, loc, crdBuffer[tid][dim], pidxs[tid][dim]);
665
- auto enc = getSparseTensorEncoding (tensors[tid].getType ());
666
685
auto [trans, pred] =
667
- genSliceLegitPredicate (builder, loc, load, enc , dim);
686
+ genSliceLegitPredicate (builder, loc, coord[tid][dim], tid , dim);
668
687
slicesPreds.emplace_back (pred, i);
669
688
// Updates to the relative coordinate to the slice.
670
689
coord[tid][dim] = trans;
@@ -679,7 +698,7 @@ Operation *LoopEmitter::enterCoIterationOverTensorsAtDims(
679
698
// Generates a list of if statments
680
699
// pidx = in_slice ? pidx : pidx + 1
681
700
// TODO: instead of always picking pidx + 1, we should set pidx = high to
682
- // break to loop the coordinates is larger than the slice size.
701
+ // break to loop if the coordinates is larger than the slice size.
683
702
for (auto [pred, idx] : slicesPreds) {
684
703
Value nextPidx = builder.create <arith::AddIOp>(
685
704
loc, yields[idx], constantIndex (builder, loc, 1 ));
0 commit comments