10 | 10 | // visible buffers and actual compiler IR that implements these primitives on
11 | 11 | // the selected sparse tensor storage schemes. This pass provides an alternative
12 | 12 | // to the SparseTensorConversion pass, eliminating the dependence on a runtime
13 |    | -// support library, and providing much more opportunities for subsequent
14 |    | -// compiler optimization of the generated code.
   | 13 | +// support library (other than for file I/O), and providing many more
   | 14 | +// opportunities for subsequent compiler optimization of the generated code.
15 | 15 | //
16 | 16 | //===----------------------------------------------------------------------===//
17 | 17 |

37 | 37 | using namespace mlir;
38 | 38 | using namespace mlir::sparse_tensor;
39 | 39 |
40 |    | -namespace {
41 |    | -
42 |    | -using FuncGeneratorType =
43 |    | -    function_ref<void(OpBuilder &, ModuleOp, func::FuncOp, RankedTensorType)>;
44 |    | -
45 | 40 | //===----------------------------------------------------------------------===//
46 | 41 | // Helper methods.
47 | 42 | //===----------------------------------------------------------------------===//
48 | 43 |
49 |    | -/// Flatten a list of operands that may contain sparse tensors.
   | 44 | +/// Flattens a list of operands that may contain sparse tensors.
50 | 45 | static void flattenOperands(ValueRange operands,
51 | 46 |                             SmallVectorImpl<Value> &flattened) {
52 | 47 |   // In case of

@@ -97,6 +92,7 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value upper,
 97 | 92 |   return forOp;
 98 | 93 | }
 99 | 94 |
    | 95 | +/// Creates a push back operation.
100 | 96 | static void createPushback(OpBuilder &builder, Location loc,
101 | 97 |                            MutSparseTensorDescriptor desc,
102 | 98 |                            SparseTensorFieldKind kind, std::optional<Level> lvl,

@@ -368,6 +364,95 @@ static Value genCompressed(OpBuilder &builder, Location loc,
368 | 364 |   return ifOp2.getResult(o);
369 | 365 | }
370 | 366 |
    | 367 | +/// Generates insertion finalization code.
    | 368 | +static void genEndInsert(OpBuilder &builder, Location loc,
    | 369 | +                         SparseTensorDescriptor desc) {
    | 370 | +  const SparseTensorType stt(desc.getRankedTensorType());
    | 371 | +  const Level lvlRank = stt.getLvlRank();
    | 372 | +  for (Level l = 0; l < lvlRank; l++) {
    | 373 | +    const auto dlt = stt.getLvlType(l);
    | 374 | +    if (isLooseCompressedDLT(dlt))
    | 375 | +      llvm_unreachable("TODO: Not yet implemented");
    | 376 | +    if (isCompressedDLT(dlt)) {
    | 377 | +      // Compressed dimensions need a position cleanup for all entries
    | 378 | +      // that were not visited during the insertion pass.
    | 379 | +      //
    | 380 | +      // TODO: avoid cleanup and keep compressed scheme consistent at all
    | 381 | +      // times?
    | 382 | +      //
    | 383 | +      if (l > 0) {
    | 384 | +        Type posType = stt.getPosType();
    | 385 | +        Value posMemRef = desc.getPosMemRef(l);
    | 386 | +        Value hi = desc.getPosMemSize(builder, loc, l);
    | 387 | +        Value zero = constantIndex(builder, loc, 0);
    | 388 | +        Value one = constantIndex(builder, loc, 1);
    | 389 | +        // Vector of only one, but needed by createFor's prototype.
    | 390 | +        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
    | 391 | +        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
    | 392 | +        Value i = loop.getInductionVar();
    | 393 | +        Value oldv = loop.getRegionIterArg(0);
    | 394 | +        Value newv = genLoad(builder, loc, posMemRef, i);
    | 395 | +        Value posZero = constantZero(builder, loc, posType);
    | 396 | +        Value cond = builder.create<arith::CmpIOp>(
    | 397 | +            loc, arith::CmpIPredicate::eq, newv, posZero);
    | 398 | +        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
    | 399 | +                                                   cond, /*else*/ true);
    | 400 | +        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    | 401 | +        genStore(builder, loc, oldv, posMemRef, i);
    | 402 | +        builder.create<scf::YieldOp>(loc, oldv);
    | 403 | +        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
    | 404 | +        builder.create<scf::YieldOp>(loc, newv);
    | 405 | +        builder.setInsertionPointAfter(ifOp);
    | 406 | +        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
    | 407 | +        builder.setInsertionPointAfter(loop);
    | 408 | +      }
    | 409 | +    } else {
    | 410 | +      assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
    | 411 | +    }
    | 412 | +  }
    | 413 | +}
    | 414 | +
    | 415 | +/// Generates a subview into the sizes.
    | 416 | +static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
    | 417 | +                            Value sz) {
    | 418 | +  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
    | 419 | +  return builder
    | 420 | +      .create<memref::SubViewOp>(
    | 421 | +          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
    | 422 | +          ValueRange{}, ValueRange{sz}, ValueRange{},
    | 423 | +          ArrayRef<int64_t>{0},                    // static offset
    | 424 | +          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
    | 425 | +          ArrayRef<int64_t>{1})                    // static stride
    | 426 | +      .getResult();
    | 427 | +}
    | 428 | +
    | 429 | +/// Creates the reassociation array.
    | 430 | +static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
    | 431 | +  ReassociationIndices reassociation;
    | 432 | +  for (int i = 0, e = srcTp.getRank(); i < e; i++)
    | 433 | +    reassociation.push_back(i);
    | 434 | +  return reassociation;
    | 435 | +}
    | 436 | +
    | 437 | +/// Generates scalar to tensor cast.
    | 438 | +static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
    | 439 | +                               Type dstTp) {
    | 440 | +  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
    | 441 | +    // Scalars can only be converted to 0-ranked tensors.
    | 442 | +    if (rtp.getRank() != 0)
    | 443 | +      return nullptr;
    | 444 | +    elem = genCast(builder, loc, elem, rtp.getElementType());
    | 445 | +    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
    | 446 | +  }
    | 447 | +  return genCast(builder, loc, elem, dstTp);
    | 448 | +}
    | 449 | +
    | 450 | +//===----------------------------------------------------------------------===//
    | 451 | +// Codegen rules.
    | 452 | +//===----------------------------------------------------------------------===//
    | 453 | +
    | 454 | +namespace {
    | 455 | +
371 | 456 | /// Helper class to help lowering sparse_tensor.insert operation.
372 | 457 | class SparseInsertGenerator
373 | 458 |     : public FuncCallOrInlineGenerator<SparseInsertGenerator> {

@@ -472,90 +557,6 @@ class SparseInsertGenerator
472 | 557 |   TensorType rtp;
473 | 558 | };
474 | 559 |
475 |     | -/// Generations insertion finalization code.
476 |     | -static void genEndInsert(OpBuilder &builder, Location loc,
477 |     | -                         SparseTensorDescriptor desc) {
478 |     | -  const SparseTensorType stt(desc.getRankedTensorType());
479 |     | -  const Level lvlRank = stt.getLvlRank();
480 |     | -  for (Level l = 0; l < lvlRank; l++) {
481 |     | -    const auto dlt = stt.getLvlType(l);
482 |     | -    if (isLooseCompressedDLT(dlt))
483 |     | -      llvm_unreachable("TODO: Not yet implemented");
484 |     | -    if (isCompressedDLT(dlt)) {
485 |     | -      // Compressed dimensions need a position cleanup for all entries
486 |     | -      // that were not visited during the insertion pass.
487 |     | -      //
488 |     | -      // TODO: avoid cleanup and keep compressed scheme consistent at all
489 |     | -      // times?
490 |     | -      //
491 |     | -      if (l > 0) {
492 |     | -        Type posType = stt.getPosType();
493 |     | -        Value posMemRef = desc.getPosMemRef(l);
494 |     | -        Value hi = desc.getPosMemSize(builder, loc, l);
495 |     | -        Value zero = constantIndex(builder, loc, 0);
496 |     | -        Value one = constantIndex(builder, loc, 1);
497 |     | -        // Vector of only one, but needed by createFor's prototype.
498 |     | -        SmallVector<Value, 1> inits{genLoad(builder, loc, posMemRef, zero)};
499 |     | -        scf::ForOp loop = createFor(builder, loc, hi, inits, one);
500 |     | -        Value i = loop.getInductionVar();
501 |     | -        Value oldv = loop.getRegionIterArg(0);
502 |     | -        Value newv = genLoad(builder, loc, posMemRef, i);
503 |     | -        Value posZero = constantZero(builder, loc, posType);
504 |     | -        Value cond = builder.create<arith::CmpIOp>(
505 |     | -            loc, arith::CmpIPredicate::eq, newv, posZero);
506 |     | -        scf::IfOp ifOp = builder.create<scf::IfOp>(loc, TypeRange(posType),
507 |     | -                                                   cond, /*else*/ true);
508 |     | -        builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
509 |     | -        genStore(builder, loc, oldv, posMemRef, i);
510 |     | -        builder.create<scf::YieldOp>(loc, oldv);
511 |     | -        builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
512 |     | -        builder.create<scf::YieldOp>(loc, newv);
513 |     | -        builder.setInsertionPointAfter(ifOp);
514 |     | -        builder.create<scf::YieldOp>(loc, ifOp.getResult(0));
515 |     | -        builder.setInsertionPointAfter(loop);
516 |     | -      }
517 |     | -    } else {
518 |     | -      assert(isDenseDLT(dlt) || isSingletonDLT(dlt));
519 |     | -    }
520 |     | -  }
521 |     | -}
522 |     | -
523 |     | -static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
524 |     | -                            Value sz) {
525 |     | -  auto elemTp = llvm::cast<MemRefType>(mem.getType()).getElementType();
526 |     | -  return builder
527 |     | -      .create<memref::SubViewOp>(
528 |     | -          loc, MemRefType::get({ShapedType::kDynamic}, elemTp), mem,
529 |     | -          ValueRange{}, ValueRange{sz}, ValueRange{},
530 |     | -          ArrayRef<int64_t>{0},                    // static offset
531 |     | -          ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size
532 |     | -          ArrayRef<int64_t>{1})                    // static stride
533 |     | -      .getResult();
534 |     | -}
535 |     | -
536 |     | -static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
537 |     | -  ReassociationIndices reassociation;
538 |     | -  for (int i = 0, e = srcTp.getRank(); i < e; i++)
539 |     | -    reassociation.push_back(i);
540 |     | -  return reassociation;
541 |     | -}
542 |     | -
543 |     | -static Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
544 |     | -                               Type dstTp) {
545 |     | -  if (auto rtp = dstTp.dyn_cast<RankedTensorType>()) {
546 |     | -    // Scalars can only be converted to 0-ranked tensors.
547 |     | -    if (rtp.getRank() != 0)
548 |     | -      return nullptr;
549 |     | -    elem = genCast(builder, loc, elem, rtp.getElementType());
550 |     | -    return builder.create<tensor::FromElementsOp>(loc, rtp, elem);
551 |     | -  }
552 |     | -  return genCast(builder, loc, elem, dstTp);
553 |     | -}
554 |     | -
555 |     | -//===----------------------------------------------------------------------===//
556 |     | -// Codegen rules.
557 |     | -//===----------------------------------------------------------------------===//
558 |     | -
559 | 560 | /// Sparse tensor storage conversion rule for returns.
560 | 561 | class SparseReturnConverter : public OpConversionPattern<func::ReturnOp> {
561 | 562 | public: