@@ -374,6 +374,11 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
374
374
return false ;
375
375
}
376
376
377
+ // / Determines if the given value is a dense tensor instead of a sparse one.
378
+ static bool isDenseTensor (Value v) {
379
+ return (sparse_tensor::getSparseTensorType (v).isAllDense ());
380
+ }
381
+
377
382
/// Test for sorted COO with suitable data and coordinates types.
378
383
static bool isAdmissibleCOO (SparseTensorType &aTp) {
379
384
return aTp.isCompressedLvl (0 ) && aTp.isOrderedLvl (0 ) && !aTp.isUniqueLvl (0 ) &&
@@ -656,6 +661,109 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
656
661
return success ();
657
662
}
658
663
664
+ // Match and rewrite 2:4 SpMM kernels.
665
+ static LogicalResult rewrite2To4SpMM (PatternRewriter &rewriter,
666
+ linalg::GenericOp op) {
667
+ Location loc = op.getLoc ();
668
+ Value A = op.getOperand (0 );
669
+ Value B = op.getOperand (1 );
670
+ Value C = op.getOperand (2 ); // we have C = AB
671
+ SmallVector<Value> tokens;
672
+
673
+ // All input should be dense tensors.
674
+ if (!isDenseTensor (A) || !isDenseTensor (B) || !isDenseTensor (C))
675
+ return failure ();
676
+
677
+ Value bufA = genTensorToMemref (rewriter, loc, A);
678
+ Value matA = genAllocCopy (rewriter, loc, bufA, tokens);
679
+ Value bufB = genTensorToMemref (rewriter, loc, B);
680
+ Value matB = genAllocCopy (rewriter, loc, bufB, tokens);
681
+ Value bufC = genTensorToMemref (rewriter, loc, C);
682
+ Value matC = genAllocCopy (rewriter, loc, bufC, tokens);
683
+ genBlockingWait (rewriter, loc, tokens);
684
+ tokens.clear ();
685
+ Value szm = linalg::createOrFoldDimOp (rewriter, loc, matA, 0 );
686
+ Value szk = linalg::createOrFoldDimOp (rewriter, loc, matB, 0 );
687
+ Value szn = linalg::createOrFoldDimOp (rewriter, loc, matC, 1 );
688
+
689
+ Type indexTp = rewriter.getIndexType ();
690
+ Type dnTensorHandleTp = rewriter.getType <gpu::SparseDnTensorHandleType>();
691
+ Type spMatHandleTp = rewriter.getType <gpu::SparseSpMatHandleType>();
692
+ Type tokenTp = rewriter.getType <gpu::AsyncTokenType>();
693
+ Value token = genFirstWait (rewriter, loc);
694
+ Operation *spGenA = rewriter.create <gpu::Create2To4SpMatOp>(
695
+ loc, spMatHandleTp, tokenTp, token, szm, szk, matA);
696
+
697
+ Value spMatA = spGenA->getResult (0 );
698
+ token = spGenA->getResult (1 );
699
+ auto dmatB = rewriter.create <gpu::CreateDnTensorOp>(
700
+ loc, dnTensorHandleTp, tokenTp, token, matB,
701
+ SmallVector<Value>{szk, szn});
702
+ Value dnB = dmatB.getResult (0 );
703
+ token = dmatB.getAsyncToken ();
704
+ auto dmatC = rewriter.create <gpu::CreateDnTensorOp>(
705
+ loc, dnTensorHandleTp, tokenTp, token, matC,
706
+ SmallVector<Value>{szm, szn});
707
+ Value dnC = dmatC.getResult (0 );
708
+ token = dmatC.getAsyncToken ();
709
+
710
+ auto dmatCType = llvm::cast<ShapedType>(matC.getType ()).getElementType ();
711
+
712
+ // Precompute buffersize for SpMM.
713
+ SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
714
+ TypeRange bufferTypes (bufferTypes_);
715
+ auto bufferComp = rewriter.create <gpu::SpMMBufferSizeOp>(
716
+ loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
717
+ gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
718
+ /* computeType=*/ dmatCType);
719
+
720
+ token = bufferComp.getAsyncToken ();
721
+ Value bufferSz = bufferComp.getResult (0 );
722
+ auto buf = genAllocBuffer (rewriter, loc, bufferSz, token);
723
+ Value buffer = buf.getResult (0 );
724
+ token = buf.getAsyncToken ();
725
+
726
+ Value bufferSz2 = bufferComp.getResult (1 );
727
+ auto buf2 = genAllocBuffer (rewriter, loc, bufferSz2, token);
728
+ Value buffer2 = buf2.getResult (0 );
729
+ token = buf2.getAsyncToken ();
730
+
731
+ Value bufferSz3 = bufferComp.getResult (2 );
732
+ auto buf3 = genAllocBuffer (rewriter, loc, bufferSz3, token);
733
+ Value buffer3 = buf3.getResult (0 );
734
+ token = buf3.getAsyncToken ();
735
+
736
+ auto dnCType = llvm::cast<ShapedType>(matC.getType ()).getElementType ();
737
+
738
+ // Perform the SpMM.
739
+ auto spmmComp = rewriter.create <gpu::SpMMOp>(
740
+ loc, tokenTp, token, spMatA, dnB, dnC, /* computeType=*/ dnCType,
741
+ SmallVector<Value>{buffer, buffer2, buffer3});
742
+ token = spmmComp.getAsyncToken ();
743
+
744
+ // Copy data back to host and free all the resources.
745
+ token = rewriter.create <gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
746
+ .getAsyncToken ();
747
+ token = rewriter.create <gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
748
+ .getAsyncToken ();
749
+ token = rewriter.create <gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
750
+ .getAsyncToken ();
751
+ SmallVector<Value> newDynamicSizes;
752
+
753
+ token = genDeallocMemRef (rewriter, loc, buffer, token);
754
+ token = genDeallocMemRef (rewriter, loc, buffer2, token);
755
+ token = genDeallocMemRef (rewriter, loc, buffer3, token);
756
+ token = genDeallocMemRef (rewriter, loc, matA, token);
757
+ token = genDeallocMemRef (rewriter, loc, matB, token);
758
+ token = genCopyMemRef (rewriter, loc, bufC, matC, token);
759
+ token = genDeallocMemRef (rewriter, loc, matC, token);
760
+ tokens.push_back (token);
761
+ genBlockingWait (rewriter, loc, tokens);
762
+ tokens.clear ();
763
+ rewriter.replaceOpWithNewOp <bufferization::ToTensorOp>(op, bufC);
764
+ return success ();
765
+ }
766
+
659
767
/// Match and rewrite SDDMM kernel.
660
768
static LogicalResult rewriteSDDMM (PatternRewriter &rewriter,
661
769
linalg::GenericOp op, bool enableRT) {
@@ -906,6 +1014,9 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
906
1014
// TODO: add transposed {i, k}, {k, j}
907
1015
// TODO: maybe add transposed {i, j} in future
908
1016
maps == infer ({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs (op)) {
1017
+ if (op->getAttr (" DENSE24" ))
1018
+ return rewrite2To4SpMM (rewriter, op);
1019
+
909
1020
return rewriteSpMM (rewriter, op, enableRT);
910
1021
}
911
1022
0 commit comments