@@ -233,6 +233,19 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
233
233
{llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
234
234
llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
235
235
llvmInt32Type, llvmPointerType /* void *stream */ }};
236
+ FunctionCallBuilder createCscCallBuilder = {
237
+ " mgpuCreateCsc" ,
238
+ llvmPointerType,
239
+ {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmPointerType,
240
+ llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
241
+ llvmInt32Type, llvmPointerType /* void *stream */ }};
242
+ FunctionCallBuilder createBsrCallBuilder = {
243
+ " mgpuCreateBsr" ,
244
+ llvmPointerType,
245
+ {llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
246
+ llvmIntPtrType, llvmPointerType, llvmPointerType, llvmPointerType,
247
+ llvmInt32Type, llvmInt32Type, llvmInt32Type,
248
+ llvmPointerType /* void *stream */ }};
236
249
FunctionCallBuilder destroySpMatCallBuilder = {
237
250
" mgpuDestroySpMat" ,
238
251
llvmVoidType,
@@ -554,6 +567,8 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(DestroyDnTensorOp)
554
567
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (CreateCooOp)
555
568
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (CreateCooAoSOp)
556
569
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (CreateCsrOp)
570
+ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (CreateCscOp)
571
+ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (CreateBsrOp)
557
572
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (Create2To4SpMatOp)
558
573
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (DestroySpMatOp)
559
574
DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN (SpMVBufferSizeOp)
@@ -627,11 +642,11 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
627
642
628
643
// Corresponding to cusparseIndexType_t defined in cusparse.h.
629
644
static int32_t getCuSparseIndexTypeFrom (Type type) {
630
- if (type.isa <IndexType>( ))
631
- return 3 ; // CUSPARSE_INDEX_64I
632
- else
645
+ if (type.isInteger ( 16 ))
646
+ return 1 ; // CUSPARSE_INDEX_16U
647
+ if (type. isInteger ( 32 ))
633
648
return 2 ; // CUSPARSE_INDEX_32I
634
- // TODO: add support to CUSPARSE_INDEX_16U: 1
649
+ return 3 ; // CUSPARSE_INDEX_64I
635
650
}
636
651
637
652
static int32_t getCuSparseLtDataTypeFrom (Type type) {
@@ -684,6 +699,7 @@ static int32_t getCuSparseDataTypeFrom(Type type) {
684
699
static gpu::Prune2To4SpMatFlag get2To4PruneFlag (Value spMat) {
685
700
return spMat.getDefiningOp <gpu::Create2To4SpMatOp>().getPruneFlag ();
686
701
}
702
+
687
703
// TODO: We may want a run-time (of the mlir compiler) disablement/warning:
688
704
// cusparseLt currently won't work for cuda architecture <8.0 and will trigger a
689
705
// runtime (of the CUDA program) error , but it might be great if we could at
@@ -696,9 +712,13 @@ static bool is2To4Sparsity(Value spMat) {
696
712
return true ;
697
713
if (auto op = spMat.getDefiningOp <gpu::CreateCooOp>())
698
714
return false ;
715
+ if (auto op = spMat.getDefiningOp <gpu::CreateCooAoSOp>())
716
+ return false ;
699
717
if (auto op = spMat.getDefiningOp <gpu::CreateCsrOp>())
700
718
return false ;
701
- if (auto op = spMat.getDefiningOp <gpu::CreateCooAoSOp>())
719
+ if (auto op = spMat.getDefiningOp <gpu::CreateCscOp>())
720
+ return false ;
721
+ if (auto op = spMat.getDefiningOp <gpu::CreateBsrOp>())
702
722
return false ;
703
723
// Print the spMat defining op
704
724
spMat.getDefiningOp ()->print (llvm::errs ());
@@ -1916,6 +1936,83 @@ LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
1916
1936
return success ();
1917
1937
}
1918
1938
1939
+ LogicalResult ConvertCreateCscOpToGpuRuntimeCallPattern::matchAndRewrite (
1940
+ gpu::CreateCscOp op, OpAdaptor adaptor,
1941
+ ConversionPatternRewriter &rewriter) const {
1942
+ if (failed (areAllLLVMTypes (op, adaptor.getOperands (), rewriter)) ||
1943
+ failed (isAsyncWithOneDependency (rewriter, op)))
1944
+ return failure ();
1945
+ Location loc = op.getLoc ();
1946
+ auto stream = adaptor.getAsyncDependencies ().front ();
1947
+ Value pColPos =
1948
+ MemRefDescriptor (adaptor.getColPos ()).allocatedPtr (rewriter, loc);
1949
+ Value pRowIdxs =
1950
+ MemRefDescriptor (adaptor.getRowIdxs ()).allocatedPtr (rewriter, loc);
1951
+ Value pValues =
1952
+ MemRefDescriptor (adaptor.getValues ()).allocatedPtr (rewriter, loc);
1953
+ if (!getTypeConverter ()->useOpaquePointers ()) {
1954
+ pColPos = rewriter.create <LLVM::BitcastOp>(loc, llvmPointerType, pColPos);
1955
+ pRowIdxs = rewriter.create <LLVM::BitcastOp>(loc, llvmPointerType, pRowIdxs);
1956
+ pValues = rewriter.create <LLVM::BitcastOp>(loc, llvmPointerType, pValues);
1957
+ }
1958
+ Type pType =
1959
+ llvm::cast<MemRefType>(op.getColPos ().getType ()).getElementType ();
1960
+ Type iType =
1961
+ llvm::cast<MemRefType>(op.getRowIdxs ().getType ()).getElementType ();
1962
+ Type dType =
1963
+ llvm::cast<MemRefType>(op.getValues ().getType ()).getElementType ();
1964
+ auto ptp = genConstInt32From (rewriter, loc, getCuSparseIndexTypeFrom (pType));
1965
+ auto itp = genConstInt32From (rewriter, loc, getCuSparseIndexTypeFrom (iType));
1966
+ auto dtp = genConstInt32From (rewriter, loc, getCuSparseDataTypeFrom (dType));
1967
+ auto handle =
1968
+ createCscCallBuilder
1969
+ .create (loc, rewriter,
1970
+ {adaptor.getRows (), adaptor.getCols (), adaptor.getNnz (),
1971
+ pColPos, pRowIdxs, pValues, ptp, itp, dtp, stream})
1972
+ .getResult ();
1973
+ rewriter.replaceOp (op, {handle, stream});
1974
+ return success ();
1975
+ }
1976
+
1977
+ LogicalResult ConvertCreateBsrOpToGpuRuntimeCallPattern::matchAndRewrite (
1978
+ gpu::CreateBsrOp op, OpAdaptor adaptor,
1979
+ ConversionPatternRewriter &rewriter) const {
1980
+ if (failed (areAllLLVMTypes (op, adaptor.getOperands (), rewriter)) ||
1981
+ failed (isAsyncWithOneDependency (rewriter, op)))
1982
+ return failure ();
1983
+ Location loc = op.getLoc ();
1984
+ auto stream = adaptor.getAsyncDependencies ().front ();
1985
+ Value pRowPos =
1986
+ MemRefDescriptor (adaptor.getBRowPos ()).allocatedPtr (rewriter, loc);
1987
+ Value pColIdxs =
1988
+ MemRefDescriptor (adaptor.getBColIdxs ()).allocatedPtr (rewriter, loc);
1989
+ Value pValues =
1990
+ MemRefDescriptor (adaptor.getValues ()).allocatedPtr (rewriter, loc);
1991
+ if (!getTypeConverter ()->useOpaquePointers ()) {
1992
+ pRowPos = rewriter.create <LLVM::BitcastOp>(loc, llvmPointerType, pRowPos);
1993
+ pColIdxs = rewriter.create <LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
1994
+ pValues = rewriter.create <LLVM::BitcastOp>(loc, llvmPointerType, pValues);
1995
+ }
1996
+ Type pType =
1997
+ llvm::cast<MemRefType>(op.getBRowPos ().getType ()).getElementType ();
1998
+ Type iType =
1999
+ llvm::cast<MemRefType>(op.getBColIdxs ().getType ()).getElementType ();
2000
+ Type dType =
2001
+ llvm::cast<MemRefType>(op.getValues ().getType ()).getElementType ();
2002
+ auto ptp = genConstInt32From (rewriter, loc, getCuSparseIndexTypeFrom (pType));
2003
+ auto itp = genConstInt32From (rewriter, loc, getCuSparseIndexTypeFrom (iType));
2004
+ auto dtp = genConstInt32From (rewriter, loc, getCuSparseDataTypeFrom (dType));
2005
+ auto handle =
2006
+ createBsrCallBuilder
2007
+ .create (loc, rewriter,
2008
+ {adaptor.getBrows (), adaptor.getBcols (), adaptor.getBnnz (),
2009
+ adaptor.getRBlockSize (), adaptor.getCBlockSize (), pRowPos,
2010
+ pColIdxs, pValues, ptp, itp, dtp, stream})
2011
+ .getResult ();
2012
+ rewriter.replaceOp (op, {handle, stream});
2013
+ return success ();
2014
+ }
2015
+
1919
2016
void mlir::populateGpuToLLVMConversionPatterns (LLVMTypeConverter &converter,
1920
2017
RewritePatternSet &patterns,
1921
2018
StringRef gpuBinaryAnnotation,
@@ -1941,6 +2038,8 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
1941
2038
ConvertCreateCooOpToGpuRuntimeCallPattern,
1942
2039
ConvertCreateCooAoSOpToGpuRuntimeCallPattern,
1943
2040
ConvertCreateCsrOpToGpuRuntimeCallPattern,
2041
+ ConvertCreateCscOpToGpuRuntimeCallPattern,
2042
+ ConvertCreateBsrOpToGpuRuntimeCallPattern,
1944
2043
ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
1945
2044
ConvertDestroySpMatOpToGpuRuntimeCallPattern,
1946
2045
ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
0 commit comments