@@ -39,7 +39,7 @@ enum class CuSparseFormat {
39
39
kCOO ,
40
40
kCSR ,
41
41
kCSC ,
42
- kBSR , // TODO: coming soon!
42
+ kBSR ,
43
43
};
44
44
45
45
// ===----------------------------------------------------------------------===//
@@ -428,6 +428,19 @@ static bool isAdmissibleCSC(SparseTensorType &aTp) {
428
428
aTp.isOrderedLvl (1 ) && aTp.isUniqueLvl (1 ) && isAdmissibleMetaData (aTp);
429
429
}
430
430
431
+ // / Test for BSR matrix with suitable metadata.
432
+ static bool isAdmissibleBSR (SparseTensorType &aTp) {
433
+ if (aTp.getDimRank () == 2 && aTp.getLvlRank () == 4 && aTp.isDenseLvl (0 ) &&
434
+ aTp.isCompressedLvl (1 ) && aTp.isOrderedLvl (1 ) && aTp.isUniqueLvl (1 ) &&
435
+ aTp.isDenseLvl (2 ) && aTp.isDenseLvl (3 ) && isAdmissibleMetaData (aTp)) {
436
+ // CuSparse only supports "square" blocks currently.
437
+ SmallVector<unsigned > dims = getBlockSize (aTp.getDimToLvl ());
438
+ assert (dims.size () == 2 );
439
+ return dims[0 ] = dims[1 ] && dims[0 ] > 1 ;
440
+ }
441
+ return false ;
442
+ }
443
+
431
444
// / Returns a suitable sparse format for the operation and given operand
432
445
// / types with cuSparse, or kNone if none is available.
433
446
static CuSparseFormat getCuSparseFormat (SparseTensorType aTp,
@@ -448,6 +461,8 @@ static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
448
461
return CuSparseFormat::kCSR ;
449
462
if (isAdmissibleCSC (aTp))
450
463
return CuSparseFormat::kCSC ;
464
+ if (isAdmissibleBSR (aTp))
465
+ return CuSparseFormat::kBSR ;
451
466
return CuSparseFormat::kNone ;
452
467
}
453
468
@@ -475,9 +490,10 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
475
490
}
476
491
477
492
// / Generates the sparse matrix handle.
478
- static Operation *genSpMat (OpBuilder &builder, Location loc, Type handleTp,
479
- Type tokenTp, Value token, Value sz1, Value sz2,
480
- Value nseA, Value rowA, Value colA, Value valA,
493
+ static Operation *genSpMat (OpBuilder &builder, Location loc,
494
+ SparseTensorType &aTp, Type handleTp, Type tokenTp,
495
+ Value token, Value sz1, Value sz2, Value nseA,
496
+ Value rowA, Value colA, Value valA,
481
497
CuSparseFormat format, bool enableRT) {
482
498
if (format == CuSparseFormat::kCOO ) {
483
499
// Library uses SoA COO, direct IR uses AoS COO.
@@ -498,9 +514,24 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
498
514
if (format == CuSparseFormat::kCSR )
499
515
return builder.create <gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
500
516
sz2, nseA, rowA, colA, valA);
501
- assert (format == CuSparseFormat::kCSC );
502
- return builder.create <gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
503
- sz2, nseA, rowA, colA, valA);
517
+ if (format == CuSparseFormat::kCSC )
518
+ return builder.create <gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
519
+ sz2, nseA, rowA, colA, valA);
520
+ // BSR requires a bit more work since we need to pass in the block size
521
+ // and all others sizes in terms of blocks (#block-rows, #block-cols,
522
+ // #nonzero-blocks).
523
+ assert (format == CuSparseFormat::kBSR );
524
+ SmallVector<unsigned > dims = getBlockSize (aTp.getDimToLvl ());
525
+ assert (dims.size () == 2 && dims[0 ] == dims[1 ]);
526
+ uint64_t b = dims[0 ];
527
+ Value bSz = constantIndex (builder, loc, b);
528
+ Value bRows = builder.create <arith::DivUIOp>(loc, sz1, bSz);
529
+ Value bCols = builder.create <arith::DivUIOp>(loc, sz2, bSz);
530
+ Value bNum = builder.create <arith::DivUIOp>(
531
+ loc, nseA, constantIndex (builder, loc, b * b));
532
+ return builder.create <gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
533
+ bCols, bNum, bSz, bSz, rowA, colA,
534
+ valA);
504
535
}
505
536
506
537
// / Match and rewrite SpMV kernel.
@@ -566,8 +597,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
566
597
Type tokenTp = rewriter.getType <gpu::AsyncTokenType>();
567
598
Value token = genFirstWait (rewriter, loc);
568
599
Operation *spGenA =
569
- genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA ,
570
- rowA, colA, valA, format, enableRT);
600
+ genSpMat (rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
601
+ nseA, rowA, colA, valA, format, enableRT);
571
602
Value spMatA = spGenA->getResult (0 );
572
603
token = spGenA->getResult (1 );
573
604
auto dvecX = rewriter.create <gpu::CreateDnTensorOp>(
@@ -691,8 +722,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
691
722
Type tokenTp = rewriter.getType <gpu::AsyncTokenType>();
692
723
Value token = genFirstWait (rewriter, loc);
693
724
Operation *spGenA =
694
- genSpMat (rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA ,
695
- rowA, colA, valA, format, enableRT);
725
+ genSpMat (rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
726
+ nseA, rowA, colA, valA, format, enableRT);
696
727
Value spMatA = spGenA->getResult (0 );
697
728
token = spGenA->getResult (1 );
698
729
auto dmatB = rewriter.create <gpu::CreateDnTensorOp>(
@@ -806,13 +837,13 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
806
837
Type tokenTp = rewriter.getType <gpu::AsyncTokenType>();
807
838
Value token = genFirstWait (rewriter, loc);
808
839
Operation *spGenA =
809
- genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA ,
810
- rowA, colA, valA, format, enableRT);
840
+ genSpMat (rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
841
+ nseA, rowA, colA, valA, format, enableRT);
811
842
Value spMatA = spGenA->getResult (0 );
812
843
token = spGenA->getResult (1 );
813
844
Operation *spGenB =
814
- genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB ,
815
- rowB, colB, valB, format, enableRT);
845
+ genSpMat (rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
846
+ nseB, rowB, colB, valB, format, enableRT);
816
847
Value spMatB = spGenB->getResult (0 );
817
848
token = spGenB->getResult (1 );
818
849
@@ -830,8 +861,8 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
830
861
Value valC = e3 .getResult (0 ); // no free needed
831
862
token = e3 .getAsyncToken ();
832
863
Operation *spGenC =
833
- genSpMat (rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero ,
834
- rowC, colC, valC, format, enableRT);
864
+ genSpMat (rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
865
+ zero, rowC, colC, valC, format, enableRT);
835
866
Value spMatC = spGenC->getResult (0 );
836
867
token = spGenC->getResult (1 );
837
868
@@ -1137,8 +1168,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
1137
1168
Value dnB = dmatB.getResult (0 );
1138
1169
token = dmatB.getAsyncToken ();
1139
1170
Operation *spGenC =
1140
- genSpMat (rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC ,
1141
- rowC, colC, valC, format, enableRT);
1171
+ genSpMat (rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
1172
+ nseC, rowC, colC, valC, format, enableRT);
1142
1173
Value spMatC = spGenC->getResult (0 );
1143
1174
token = spGenC->getResult (1 );
1144
1175
auto dnCType = llvm::cast<ShapedType>(c.getType ()).getElementType ();
0 commit comments