@@ -195,16 +195,39 @@ static Value genIsNonzero(ConversionPatternRewriter &rewriter, Location loc,
195
195
llvm_unreachable (" Unknown element type" );
196
196
}
197
197
198
+ // / Generates the code to read the value from tensor[ivs], and conditionally
199
+ // / stores the indices ivs to the memory in ind. The generated code looks like
200
+ // / the following and the insertion point after this routine is inside the
201
+ // / if-then branch behind the assignment to ind. This is to ensure that the
202
+ // / addEltX call generated after is inside the if-then branch.
203
+ // / if (tensor[ivs]!=0) {
204
+ // / ind = ivs
205
+ static Value genIndexAndValueForDense (ConversionPatternRewriter &rewriter,
206
+ Operation *op, Type eltType, Value tensor,
207
+ Value ind, ValueRange ivs) {
208
+ Location loc = op->getLoc ();
209
+ Value val = rewriter.create <tensor::ExtractOp>(loc, tensor, ivs);
210
+ Value cond = genIsNonzero (rewriter, loc, eltType, val);
211
+ scf::IfOp ifOp = rewriter.create <scf::IfOp>(loc, cond, /* else*/ false );
212
+ rewriter.setInsertionPointToStart (&ifOp.thenRegion ().front ());
213
+ unsigned i = 0 ;
214
+ for (auto iv : ivs) {
215
+ Value idx = rewriter.create <ConstantOp>(loc, rewriter.getIndexAttr (i++));
216
+ rewriter.create <memref::StoreOp>(loc, iv, ind, idx);
217
+ }
218
+ return val;
219
+ }
220
+
198
221
// / Generates a call that adds one element to a coordinate scheme.
199
222
// / In particular, this generates code like the following:
200
223
// / val = a[i1,..,ik];
201
224
// / if val != 0
202
225
// / t->add(val, [i1,..,ik], [p1,..,pk]);
203
226
static void genAddEltCall (ConversionPatternRewriter &rewriter, Operation *op,
204
- Value ptr, Value tensor, Value ind, Value perm,
205
- ValueRange ivs) {
227
+ Type eltType, Value ptr, Value val, Value ind,
228
+ Value perm) {
229
+ Location loc = op->getLoc ();
206
230
StringRef name;
207
- Type eltType = tensor.getType ().cast <ShapedType>().getElementType ();
208
231
if (eltType.isF64 ())
209
232
name = " addEltF64" ;
210
233
else if (eltType.isF32 ())
@@ -219,16 +242,6 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
219
242
name = " addEltI8" ;
220
243
else
221
244
llvm_unreachable (" Unknown element type" );
222
- Location loc = op->getLoc ();
223
- Value val = rewriter.create <tensor::ExtractOp>(loc, tensor, ivs);
224
- Value cond = genIsNonzero (rewriter, loc, eltType, val);
225
- scf::IfOp ifOp = rewriter.create <scf::IfOp>(loc, cond, /* else*/ false );
226
- rewriter.setInsertionPointToStart (&ifOp.thenRegion ().front ());
227
- unsigned i = 0 ;
228
- for (auto iv : ivs) {
229
- Value idx = rewriter.create <ConstantOp>(loc, rewriter.getIndexAttr (i++));
230
- rewriter.create <memref::StoreOp>(loc, iv, ind, idx);
231
- }
232
245
SmallVector<Value, 8 > params;
233
246
params.push_back (ptr);
234
247
params.push_back (val);
@@ -240,6 +253,41 @@ static void genAddEltCall(ConversionPatternRewriter &rewriter, Operation *op,
240
253
params);
241
254
}
242
255
256
+ // / If the tensor is a sparse constant, generates and returns the pair of
257
+ // / the constants for the indices and the values.
258
+ static Optional<std::pair<Value, Value>>
259
+ genSplitSparseConstant (ConversionPatternRewriter &rewriter, ConvertOp op,
260
+ Value tensor) {
261
+ if (auto constOp = tensor.getDefiningOp <ConstantOp>()) {
262
+ if (auto attr = constOp.value ().dyn_cast <SparseElementsAttr>()) {
263
+ Location loc = op->getLoc ();
264
+ DenseElementsAttr indicesAttr = attr.getIndices ();
265
+ Value indices = rewriter.create <ConstantOp>(loc, indicesAttr);
266
+ DenseElementsAttr valuesAttr = attr.getValues ();
267
+ Value values = rewriter.create <ConstantOp>(loc, valuesAttr);
268
+ return std::make_pair (indices, values);
269
+ }
270
+ }
271
+ return {};
272
+ }
273
+
274
+ // / Generates the code to copy the index at indices[ivs] to ind, and return
275
+ // / the value at value[ivs].
276
+ static Value genIndexAndValueForSparse (ConversionPatternRewriter &rewriter,
277
+ Operation *op, Value indices,
278
+ Value values, Value ind, ValueRange ivs,
279
+ unsigned rank) {
280
+ Location loc = op->getLoc ();
281
+ for (unsigned i = 0 ; i < rank; i++) {
282
+ Value idx = rewriter.create <ConstantOp>(loc, rewriter.getIndexAttr (i));
283
+ Value val = rewriter.create <tensor::ExtractOp>(loc, indices,
284
+ ValueRange{ivs[0 ], idx});
285
+ val = rewriter.create <IndexCastOp>(loc, val, rewriter.getIndexType ());
286
+ rewriter.create <memref::StoreOp>(loc, val, ind, idx);
287
+ }
288
+ return rewriter.create <tensor::ExtractOp>(loc, values, ivs[0 ]);
289
+ }
290
+
243
291
// ===----------------------------------------------------------------------===//
244
292
// Conversion rules.
245
293
// ===----------------------------------------------------------------------===//
@@ -330,15 +378,26 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
330
378
// TODO: sparse => dense
331
379
return failure ();
332
380
}
333
- // This is a dense => sparse conversion, which is handled as follows:
381
+ // This is a dense => sparse conversion or a sparse constant in COO =>
382
+ // sparse conversion, which is handled as follows:
334
383
// t = newSparseCOO()
384
+ // ...code to fill the COO tensor t...
385
+ // s = newSparseTensor(t)
386
+ //
387
+ // To fill the COO tensor from a dense tensor:
335
388
// for i1 in dim1
336
389
// ..
337
390
// for ik in dimk
338
391
// val = a[i1,..,ik]
339
392
// if val != 0
340
393
// t->add(val, [i1,..,ik], [p1,..,pk])
341
- // s = newSparseTensor(t)
394
+ //
395
+ // To fill the COO tensor from a sparse constant in COO format:
396
+ // for i in range(NNZ)
397
+ // val = values[i]
398
+ // [i1,..,ik] = indices[i]
399
+ // t->add(val, [i1,..,ik], [p1,..,pk])
400
+ //
342
401
// Note that the dense tensor traversal code is actually implemented
343
402
// using MLIR IR to avoid having to expose too much low-level
344
403
// memref traversal details to the runtime support library.
@@ -351,7 +410,6 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
351
410
MemRefType::get ({ShapedType::kDynamicSize }, rewriter.getIndexType ());
352
411
Value perm;
353
412
Value ptr = genNewCall (rewriter, op, encDst, 2 , perm);
354
- Value tensor = adaptor.getOperands ()[0 ];
355
413
Value arg = rewriter.create <ConstantOp>(
356
414
loc, rewriter.getIndexAttr (shape.getRank ()));
357
415
Value ind = rewriter.create <memref::AllocaOp>(loc, memTp, ValueRange{arg});
@@ -360,16 +418,38 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
360
418
SmallVector<Value> st;
361
419
Value zero = rewriter.create <ConstantOp>(loc, rewriter.getIndexAttr (0 ));
362
420
Value one = rewriter.create <ConstantOp>(loc, rewriter.getIndexAttr (1 ));
363
- for (unsigned i = 0 , rank = shape.getRank (); i < rank; i++) {
421
+ Value tensor = adaptor.getOperands ()[0 ];
422
+ auto indicesValues = genSplitSparseConstant (rewriter, op, tensor);
423
+ bool isCOOConstant = indicesValues.hasValue ();
424
+ Value indices;
425
+ Value values;
426
+ if (isCOOConstant) {
427
+ indices = indicesValues->first ;
428
+ values = indicesValues->second ;
364
429
lo.push_back (zero);
365
- hi.push_back (linalg::createOrFoldDimOp (rewriter, loc, tensor, i ));
430
+ hi.push_back (linalg::createOrFoldDimOp (rewriter, loc, values, 0 ));
366
431
st.push_back (one);
432
+ } else {
433
+ for (unsigned i = 0 , rank = shape.getRank (); i < rank; i++) {
434
+ lo.push_back (zero);
435
+ hi.push_back (linalg::createOrFoldDimOp (rewriter, loc, tensor, i));
436
+ st.push_back (one);
437
+ }
367
438
}
439
+ Type eltType = shape.getElementType ();
440
+ unsigned rank = shape.getRank ();
368
441
scf::buildLoopNest (rewriter, op.getLoc (), lo, hi, st, {},
369
442
[&](OpBuilder &builder, Location loc, ValueRange ivs,
370
443
ValueRange args) -> scf::ValueVector {
371
- genAddEltCall (rewriter, op, ptr, tensor, ind, perm,
372
- ivs);
444
+ Value val;
445
+ if (isCOOConstant)
446
+ val = genIndexAndValueForSparse (
447
+ rewriter, op, indices, values, ind, ivs, rank);
448
+ else
449
+ val = genIndexAndValueForDense (rewriter, op, eltType,
450
+ tensor, ind, ivs);
451
+ genAddEltCall (rewriter, op, eltType, ptr, val, ind,
452
+ perm);
373
453
return {};
374
454
});
375
455
rewriter.replaceOp (op, genNewCall (rewriter, op, encDst, 1 , perm, ptr));
0 commit comments