
Commit 3d89c08

[mlir][sparse] support BSR for cuSPARSE (libgen path only) (#69646)
1 parent ec10c36 commit 3d89c08

File tree: 3 files changed, +246 −25 lines


mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 50 additions & 19 deletions
@@ -39,7 +39,7 @@ enum class CuSparseFormat {
   kCOO,
   kCSR,
   kCSC,
-  kBSR, // TODO: coming soon!
+  kBSR,
 };
 
 //===----------------------------------------------------------------------===//
@@ -428,6 +428,19 @@ static bool isAdmissibleCSC(SparseTensorType &aTp) {
          aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) && isAdmissibleMetaData(aTp);
 }
 
+/// Test for BSR matrix with suitable metadata.
+static bool isAdmissibleBSR(SparseTensorType &aTp) {
+  if (aTp.getDimRank() == 2 && aTp.getLvlRank() == 4 && aTp.isDenseLvl(0) &&
+      aTp.isCompressedLvl(1) && aTp.isOrderedLvl(1) && aTp.isUniqueLvl(1) &&
+      aTp.isDenseLvl(2) && aTp.isDenseLvl(3) && isAdmissibleMetaData(aTp)) {
+    // CuSparse only supports "square" blocks currently.
+    SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
+    assert(dims.size() == 2);
+    return dims[0] == dims[1] && dims[0] > 1;
+  }
+  return false;
+}
+
 /// Returns a suitable sparse format for the operation and given operand
 /// types with cuSparse, or kNone if none is available.
 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
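For illustration, an encoding that satisfies these admissibility checks is the 2×2 block type declared in the new integration test added by this commit (dimension rank 2, level rank 4, levels dense/compressed/dense/dense, square blocks of size 2):

#BSR = #sparse_tensor.encoding<{
  map = (i, j) -> (
    i floordiv 2 : dense,
    j floordiv 2 : compressed,
    i mod 2      : dense,
    j mod 2      : dense)
}>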
@@ -448,6 +461,8 @@ static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
     return CuSparseFormat::kCSR;
   if (isAdmissibleCSC(aTp))
     return CuSparseFormat::kCSC;
+  if (isAdmissibleBSR(aTp))
+    return CuSparseFormat::kBSR;
   return CuSparseFormat::kNone;
 }

@@ -475,9 +490,10 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
 }
 
 /// Generates the sparse matrix handle.
-static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
-                           Type tokenTp, Value token, Value sz1, Value sz2,
-                           Value nseA, Value rowA, Value colA, Value valA,
+static Operation *genSpMat(OpBuilder &builder, Location loc,
+                           SparseTensorType &aTp, Type handleTp, Type tokenTp,
+                           Value token, Value sz1, Value sz2, Value nseA,
+                           Value rowA, Value colA, Value valA,
                            CuSparseFormat format, bool enableRT) {
   if (format == CuSparseFormat::kCOO) {
     // Library uses SoA COO, direct IR uses AoS COO.
@@ -498,9 +514,24 @@ static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
   if (format == CuSparseFormat::kCSR)
     return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, sz1,
                                             sz2, nseA, rowA, colA, valA);
-  assert(format == CuSparseFormat::kCSC);
-  return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
-                                          sz2, nseA, rowA, colA, valA);
+  if (format == CuSparseFormat::kCSC)
+    return builder.create<gpu::CreateCscOp>(loc, handleTp, tokenTp, token, sz1,
+                                            sz2, nseA, rowA, colA, valA);
+  // BSR requires a bit more work since we need to pass in the block size
+  // and all other sizes in terms of blocks (#block-rows, #block-cols,
+  // #nonzero-blocks).
+  assert(format == CuSparseFormat::kBSR);
+  SmallVector<unsigned> dims = getBlockSize(aTp.getDimToLvl());
+  assert(dims.size() == 2 && dims[0] == dims[1]);
+  uint64_t b = dims[0];
+  Value bSz = constantIndex(builder, loc, b);
+  Value bRows = builder.create<arith::DivUIOp>(loc, sz1, bSz);
+  Value bCols = builder.create<arith::DivUIOp>(loc, sz2, bSz);
+  Value bNum = builder.create<arith::DivUIOp>(
+      loc, nseA, constantIndex(builder, loc, b * b));
+  return builder.create<gpu::CreateBsrOp>(loc, handleTp, tokenTp, token, bRows,
+                                          bCols, bNum, bSz, bSz, rowA, colA,
+                                          valA);
 }
 
 /// Match and rewrite SpMV kernel.
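To make the block arithmetic concrete: the new integration test samples a 4×6 matrix with 2×2 blocks, so b = 2 and the BSR storage holds nseA = 12 values (three dense blocks of four entries each). The generated divisions then yield bRows = 4/2 = 2, bCols = 6/2 = 3, and bNum = 12/(2·2) = 3, which gpu::CreateBsrOp receives together with the block size itself.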
@@ -566,8 +597,8 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szY, szX, nseA,
-               rowA, colA, valA, format, enableRT);
+      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szY, szX,
+               nseA, rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dvecX = rewriter.create<gpu::CreateDnTensorOp>(
@@ -691,8 +722,8 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
-      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, format, enableRT);
+      genSpMat(rewriter, loc, aTp, spMatHandleTp, tokenTp, token, szm, szk,
+               nseA, rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -806,13 +837,13 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
   Value token = genFirstWait(rewriter, loc);
   Operation *spGenA =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
-               rowA, colA, valA, format, enableRT);
+      genSpMat(rewriter, loc, aTp, spmatHandleTp, tokenTp, token, szm, szk,
+               nseA, rowA, colA, valA, format, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   Operation *spGenB =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
-               rowB, colB, valB, format, enableRT);
+      genSpMat(rewriter, loc, bTp, spmatHandleTp, tokenTp, token, szk, szn,
+               nseB, rowB, colB, valB, format, enableRT);
   Value spMatB = spGenB->getResult(0);
   token = spGenB->getResult(1);

@@ -830,8 +861,8 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value valC = e3.getResult(0); // no free needed
   token = e3.getAsyncToken();
   Operation *spGenC =
-      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
-               rowC, colC, valC, format, enableRT);
+      genSpMat(rewriter, loc, cTp, spmatHandleTp, tokenTp, token, szm, szn,
+               zero, rowC, colC, valC, format, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);

@@ -1137,8 +1168,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value dnB = dmatB.getResult(0);
   token = dmatB.getAsyncToken();
   Operation *spGenC =
-      genSpMat(rewriter, loc, spMatHandleTp, tokenTp, token, szm, szn, nseC,
-               rowC, colC, valC, format, enableRT);
+      genSpMat(rewriter, loc, cTp, spMatHandleTp, tokenTp, token, szm, szn,
+               nseC, rowC, colC, valC, format, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);
   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();

mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-sampled-matmul-lib.mlir

Lines changed: 7 additions & 6 deletions
@@ -3,7 +3,8 @@
 //
 // DEFINE: %{compile} = mlir-opt %s \
 // DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
-// DEFINE: %{run} = TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
+// DEFINE: %{run} = \
+// DEFINE:   env TENSOR0="%mlir_src_dir/test/Integration/data/test.mtx" \
 // DEFINE: mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
@@ -12,16 +13,16 @@
 //
 // with RT lib:
 //
-// RUN: %{compile} enable-runtime-library=true" | %{run}
-// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
-// Tracker #64316
-// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
+// RUN: %{compile} enable-runtime-library=true" | %{run}
+// RUN: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=pinned-dma" | %{run}
+// TODO: Tracker #64316
+// RUNNOT: %{compile} enable-runtime-library=true gpu-data-transfer-strategy=zero-copy" | %{run}
 //
 // without RT lib:
 //
 // RUN: %{compile} enable-runtime-library=false" | %{run}
 // RUN: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=pinned-dma" | %{run}
-// Tracker #64316
+// TODO: Tracker #64316
 // RUNNOT: %{compile} enable-runtime-library=false gpu-data-transfer-strategy=zero-copy" | %{run}
 //
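The %{run} redefinition switches from a bare TENSOR0=... command prefix to an explicit env prefix; presumably this is for portability, since plain VAR=value command prefixes are shell-specific while env also works under lit's internal shell.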

Lines changed: 189 additions & 0 deletions (new file)
@@ -0,0 +1,189 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --sparse-compiler="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE: %{run} = \
+// DEFINE:   env TENSOR0="%mlir_src_dir/test/Integration/data/block.mtx" \
+// DEFINE: mlir-cpu-runner \
+// DEFINE:   --shared-libs=%mlir_cuda_runtime \
+// DEFINE:   --shared-libs=%mlir_c_runner_utils \
+// DEFINE:   --e entry --entry-point-result=void \
+// DEFINE: | FileCheck %s
+//
+// with RT lib:
+//
+// RUN: %{compile} enable-runtime-library=true" | %{run}
+//
+// without RT lib:
+//
+// TODO: make this work
+// R_UN: %{compile} enable-runtime-library=false" | %{run}
+//
+
+!Filename = !llvm.ptr<i8>
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+#BSR = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i floordiv 2 : dense,
+    j floordiv 2 : compressed,
+    i mod 2      : dense,
+    j mod 2      : dense)
+}>
+
+#trait_SDDMM = {
+  indexing_maps = [
+    affine_map<(i,j,k) -> (i,k)>,  // A
+    affine_map<(i,j,k) -> (k,j)>,  // B
+    affine_map<(i,j,k) -> (i,j)>   // S (in/out)
+  ],
+  iterator_types = ["parallel", "parallel", "reduction"],
+  doc = "S(i,j) += spy[S(i,j)] x SUM_k A(i,k) B(k,j)"
+}
+
+//
+// Integration test that lowers a kernel annotated as sparse to
+// actual sparse code, initializes sparse storage schemes, and
+// runs the resulting code with the JIT compiler.
+//
+module {
+  llvm.func @mgpuCreateSparseEnv()
+  llvm.func @mgpuDestroySparseEnv()
+
+  //
+  // A kernel that computes a CSR sampled dense matrix matrix multiplication
+  // using a "spy" function and in-place update of the sampling sparse matrix.
+  //
+  func.func @SDDMM(%args: tensor<?x?xf32, #CSR>,
+                   %arga: tensor<?x?xf32>,
+                   %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #CSR> {
+    %result = linalg.generic #trait_SDDMM
+      ins(%arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%args: tensor<?x?xf32, #CSR>) {
+        ^bb(%a: f32, %b: f32, %s: f32):
+          %f0 = arith.constant 0.0 : f32
+          %u = sparse_tensor.unary %s : f32 to f32
+            present={
+              ^bb0(%p: f32):
+                %mul = arith.mulf %a, %b : f32
+                sparse_tensor.yield %mul : f32
+            }
+            absent={}
+          %r = sparse_tensor.reduce %s, %u, %f0 : f32 {
+            ^bb0(%p: f32, %q: f32):
+              %add = arith.addf %p, %q : f32
+              sparse_tensor.yield %add : f32
+          }
+          linalg.yield %r : f32
+      } -> tensor<?x?xf32, #CSR>
+    return %result : tensor<?x?xf32, #CSR>
+  }
+
+  //
+  // A kernel that computes a BSR sampled dense matrix matrix multiplication
+  // using a "spy" function and in-place update of the sampling sparse matrix.
+  //
+  func.func @SDDMM_block(%args: tensor<?x?xf32, #BSR>,
+                         %arga: tensor<?x?xf32>,
+                         %argb: tensor<?x?xf32>) -> tensor<?x?xf32, #BSR> {
+    %result = linalg.generic #trait_SDDMM
+      ins(%arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%args: tensor<?x?xf32, #BSR>) {
+        ^bb(%a: f32, %b: f32, %s: f32):
+          %f0 = arith.constant 0.0 : f32
+          %u = sparse_tensor.unary %s : f32 to f32
+            present={
+              ^bb0(%p: f32):
+                %mul = arith.mulf %a, %b : f32
+                sparse_tensor.yield %mul : f32
+            }
+            absent={}
+          %r = sparse_tensor.reduce %s, %u, %f0 : f32 {
+            ^bb0(%p: f32, %q: f32):
+              %add = arith.addf %p, %q : f32
+              sparse_tensor.yield %add : f32
+          }
+          linalg.yield %r : f32
+      } -> tensor<?x?xf32, #BSR>
+    return %result : tensor<?x?xf32, #BSR>
+  }
+
+  func.func private @getTensorFilename(index) -> (!Filename)
+
+  //
+  // Main driver.
+  //
+  func.func @entry() {
+    llvm.call @mgpuCreateSparseEnv() : () -> ()
+    %d0 = arith.constant 0.0 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    %c6 = arith.constant 6 : index
+
+    // Initialize dense matrices.
+    %a = tensor.generate %c4, %c4 {
+      ^bb0(%i: index, %j: index):
+        %p = arith.addi %i, %c1 : index
+        %q = arith.index_cast %p : index to i32
+        %d = arith.sitofp %q : i32 to f32
+        tensor.yield %d : f32
+    } : tensor<?x?xf32>
+    %b = tensor.generate %c4, %c6 {
+      ^bb0(%i: index, %j: index):
+        %p = arith.addi %j, %c1 : index
+        %q = arith.index_cast %p : index to i32
+        %d = arith.sitofp %q : i32 to f32
+        tensor.yield %d : f32
+    } : tensor<?x?xf32>
+
+    // Read the sparse matrix from file, construct sparse storage.
+    //
+    //   +-----+-----+-----+
+    //   | 1 2 | . . | 4 . |
+    //   | . 3 | . . | . 5 |
+    //   +-----+-----+-----+
+    //   | . . | 6 7 | . . |
+    //   | . . | 8 . | . . |
+    //   +-----+-----+-----+
+    //
+    %fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
+    %m_csr = sparse_tensor.new %fileName : !Filename to tensor<?x?xf32, #CSR>
+    %m_bsr = sparse_tensor.new %fileName : !Filename to tensor<?x?xf32, #BSR>
+
+    // Call the kernel.
+    %0 = call @SDDMM(%m_csr, %a, %b)
+       : (tensor<?x?xf32, #CSR>,
+          tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32, #CSR>
+    %1 = call @SDDMM_block(%m_bsr, %a, %b)
+       : (tensor<?x?xf32, #BSR>,
+          tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32, #BSR>
+
+    //
+    // Print the result for verification. Note that the "spy" determines what
+    // dot products are sampled, but the original contents are added back to
+    // the result (which is why the block sparse version has actual results
+    // in the original zero positions).
+    //
+    // CHECK:      ( 5, 10, 24, 19, 53, 42, 55, 56 )
+    // CHECK-NEXT: ( 5, 10, 8, 19, 24, 24, 40, 53, 42, 55, 56, 64 )
+    //
+    %v0 = sparse_tensor.values %0 : tensor<?x?xf32, #CSR> to memref<?xf32>
+    %vv0 = vector.transfer_read %v0[%c0], %d0 : memref<?xf32>, vector<8xf32>
+    vector.print %vv0 : vector<8xf32>
+    %v1 = sparse_tensor.values %1 : tensor<?x?xf32, #BSR> to memref<?xf32>
+    %vv1 = vector.transfer_read %v1[%c0], %d0 : memref<?xf32>, vector<12xf32>
+    vector.print %vv1 : vector<12xf32>
+
+    // Release the resources.
+    bufferization.dealloc_tensor %0 : tensor<?x?xf32, #CSR>
+    bufferization.dealloc_tensor %1 : tensor<?x?xf32, #BSR>
+
+    llvm.call @mgpuDestroySparseEnv() : () -> ()
+    return
+  }
+}
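As a sanity check on the CHECK lines above: A(i,k) = i+1 and B(k,j) = j+1 with k running over four values, so each sampled dot product adds 4·(i+1)·(j+1) to the original entry. The CSR kernel samples only the 8 stored nonzeros (for example, 5 = 1 + 4·1·1 at position (0,0) and 56 = 8 + 4·4·3 at (3,2)). The BSR kernel samples all 12 positions of the three stored 2×2 blocks, which is why computed values such as 8 (= 0 + 4·2·1 at (1,0)) and 64 (= 0 + 4·4·4 at (3,3)) appear where the original matrix holds zeros.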
