
Commit 41a07e6

[mlir][sparse] recognize NVidia 2:4 type for matmul (llvm#76758)
This removes the temporary DENSE24 attribute and replaces it with proper recognition of a dense-to-2:4 conversion. The compression will be performed on the device prior to performing the matrix multiplication. Note that we no longer need to start with the linalg.generic version; we can lift this to the proper named linalg op. Also renames some files to more consistent names.
1 parent 67c2e35 commit 41a07e6
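
For reference, a minimal sketch of the IR pattern the rewriter now recognizes, taken from the updated gpu_matmul24_lib.mlir test below: a dense tensor is converted into the NVidia 2:4 encoding and fed into a named linalg.matmul.

#NV_24 = #sparse_tensor.encoding<{
  map = ( i, j ) ->
  ( i            : dense,
    j floordiv 4 : dense,
    j mod 4      : block2_4
  )
}>

func.func @matmul24(%Ad: tensor<?x?xf16>, %B: tensor<?x?xf16>,
                    %Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
  // Dense-to-2:4 conversion; pruning and compression happen on device.
  %A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
  %C = linalg.matmul
         ins(%A, %B: tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
         outs(%Cin: tensor<?x?xf16>) -> tensor<?x?xf16>
  return %C : tensor<?x?xf16>
}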

File tree: 6 files changed, +204 -175 lines changed


mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 27 additions & 1 deletion

@@ -448,6 +448,23 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
   return false;
 }
 
+/// Test for 2:4 matrix with suitable metadata.
+static bool isAdmissible24(SparseTensorType &aTp) {
+  return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
+         aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+}
+
+/// Test for conversion into 2:4 matrix.
+static bool isConversionInto24(Value v) {
+  if (auto cnv = v.getDefiningOp<ConvertOp>()) {
+    Value a = cnv.getResult();
+    Value d = cnv.getSource();
+    SparseTensorType aTp = getSparseTensorType(a);
+    return isDenseTensor(d) && isAdmissible24(aTp);
+  }
+  return false;
+}
+
 /// Returns a suitable sparse format for the operation and given operand
 /// types with cuSparse, or kNone if none is available.
 static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,

@@ -925,6 +942,15 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
   Value C = op.getOperand(2); // we have C = AB
   SmallVector<Value> tokens;
 
+  // The cuSparselt API currently only allows pruning and compression
+  // to occur on the device. So we recognize the pattern
+  //   A' = convert A ; dense to 2:4
+  //   C = A'B        ; 2:4 matrix mult
+  // and then perform compression and matrix multiplication on device.
+  auto cnv = A.getDefiningOp<ConvertOp>();
+  assert(cnv);
+  A = cnv.getSource();
+
   // All input should be dense tensors.
   if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
     return failure();

@@ -1260,7 +1286,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
       if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
         return rewriteSpGEMM(rewriter, op, enableRT);
-      if (op->getAttr("DENSE24"))
+      if (isConversionInto24(op.getOperand(0)))
        return rewrite2To4SpMM(rewriter, op);
      return rewriteSpMM(rewriter, op, enableRT);
    }
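
Because the test pipeline runs --linalg-generalize-named-ops before sparse GPU codegen, the named linalg.matmul reaches this rewriter already generalized. Below is a sketch of that generalized form, reconstructed from the removed test body (minus the retired DENSE24 attribute, with the A operand now carrying the #NV_24 type from the test below); isConversionInto24(op.getOperand(0)) fires because %A is produced by a sparse_tensor.convert from a dense tensor into an admissible 2:4 type.

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>

%0 = linalg.generic {indexing_maps = [#map, #map1, #map2],
                     iterator_types = ["parallel", "parallel", "reduction"]}
    ins(%A, %B : tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
    outs(%Cin : tensor<?x?xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
  %1 = arith.mulf %in, %in_0 : f16
  %2 = arith.addf %out, %1 : f16
  linalg.yield %2 : f16
} -> tensor<?x?xf16>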

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Lines changed: 1 addition & 1 deletion

@@ -970,7 +970,7 @@ mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
   // Note that this adds a synchronization on the stream.
   // TODO: Do we want that?
   if (prune_flag == 2) {
-    int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream);
+    int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
     CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
         &cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
     int valid = 0;

(The added boolean supplies mgpuMemAlloc's third parameter, presumably its host-shared flag, so this scratch word remains a plain device allocation.)

mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir renamed to mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir

Lines changed: 16 additions & 12 deletions

@@ -1,5 +1,13 @@
 // RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-gpu-codegen="num-threads=0" | FileCheck %s
 
+#NV_24 = #sparse_tensor.encoding<{
+  map = ( i, j ) ->
+  ( i            : dense,
+    j floordiv 4 : dense,
+    j mod 4      : block2_4
+  )
+}>
+
 // CHECK-LABEL: func.func @matmul(
 // CHECK-SAME: %[[VAL_0:.*0]]: tensor<?x?xf16>,
 // CHECK-SAME: %[[VAL_1:.*1]]: tensor<?x?xf16>,

@@ -51,18 +59,14 @@
 // CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
 // CHECK: return %[[VAL_55]] : tensor<?x?xf16>
 // CHECK: }
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
 module {
-  func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
-    %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
-    ^bb0(%in: f16, %in_0: f16, %out: f16):
-      %1 = arith.mulf %in, %in_0 : f16
-      %2 = arith.addf %out, %1 : f16
-      linalg.yield %2 : f16
-    } -> tensor<?x?xf16>
-    return %0 : tensor<?x?xf16>
+  func.func @matmul(%Ad: tensor<?x?xf16>,
+                    %B: tensor<?x?xf16>,
+                    %Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
+    %A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
+    %C = linalg.matmul
+           ins(%A, %B: tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
+           outs(%Cin: tensor<?x?xf16>) -> tensor<?x?xf16>
+    return %C : tensor<?x?xf16>
   }
 }
Lines changed: 77 additions & 62 deletions

@@ -1,40 +1,58 @@
 // NOTE: this test requires gpu-sm80 and cusparselt
 //
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE:   --sparsifier="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE: %{compile} = mlir-opt --convert-vector-to-scf --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
+// DEFINE:   --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
+// DEFINE:   %s
 // DEFINE: %{run} = mlir-cpu-runner \
 // DEFINE:   --shared-libs=%mlir_cuda_runtime \
 // DEFINE:   --shared-libs=%mlir_c_runner_utils \
 // DEFINE:   --e main --entry-point-result=void \
 // DEFINE: | FileCheck %s
 //
-// with RT lib:
-//
-// RUN: %{compile} enable-runtime-library=true" | %{run}
-//
-// without RT lib:
-//
-// RUN: %{compile} enable-runtime-library=false" | %{run}
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// RUN: %{compile} | %{run}
 
 module {
   llvm.func @mgpuCreateSparseLtEnv()
   llvm.func @mgpuDestroySparseLtEnv()
 
-  //
-  // TODO: This uses our temporary ATTRIBUTE, replace with 2:4 type!
-  //
-  func.func @matmul_2to4(%arg0: tensor<16x32xf16>, %arg1: tensor<32x16xf16>, %arg2: tensor<16x16xf16>) -> tensor<16x16xf16> {
-    %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x32xf16>, tensor<32x16xf16>) outs(%arg2 : tensor<16x16xf16>) {
-    ^bb0(%in: f16, %in_0: f16, %out: f16):
-      %1 = arith.mulf %in, %in_0 : f16
-      %2 = arith.addf %out, %1 : f16
-      linalg.yield %2 : f16
-    } -> tensor<16x16xf16>
-    return %0 : tensor<16x16xf16>
+  // cuSparselt version for matmul coded by hand.
+  func.func @matmul24(%a : memref<16x32xf16>,
+                      %b : memref<32x16xf16>,
+                      %c : memref<16x16xf16>) {
+    %c0 = arith.constant 0.0 : f16
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c1048576 = arith.constant 1048576 : index
+    %token0 = gpu.wait async
+    %d_a, %token1 = gpu.alloc async [%token0] () : memref<16x32xf16>
+    %d_b, %token2 = gpu.alloc async [%token1] () : memref<32x16xf16>
+    %d_c, %token3 = gpu.alloc async [%token2] () : memref<16x16xf16>
+    %token4 = gpu.memcpy async [%token3] %d_a, %a : memref<16x32xf16>, memref<16x32xf16>
+    %token5 = gpu.memcpy async [%token4] %d_b, %b : memref<32x16xf16>, memref<32x16xf16>
+    %token6 = gpu.memcpy async [%token5] %d_c, %c : memref<16x16xf16>, memref<16x16xf16>
+    %spmat, %token8 = gpu.create_2to4_spmat async [%token6]{PRUNE_AND_CHECK} %c16, %c32, %d_a: memref<16x32xf16>
+    %dnmat, %token9 = gpu.create_dn_tensor async [%token8] %d_b, %c32, %c16: index, index into memref<32x16xf16>
+    %dnmat2, %token10 = gpu.create_dn_tensor async [%token9] %d_c, %c16, %c16: index, index into memref<16x16xf16>
+    %bufferSz0, %bufferSz1, %bufferSz2, %token11 = gpu.spmm_buffer_size async [%token10] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index, index into f16
+    %mem1, %token12 = gpu.alloc async [%token11] (%bufferSz0) : memref<?xf16>
+    %mem2, %token13 = gpu.alloc async [%token12] (%bufferSz1) : memref<?xf16>
+    %mem3, %token14 = gpu.alloc async [%token13] (%bufferSz2) : memref<?xf16>
+    %token15 = gpu.spmm async [%token14] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>, memref<?xf16>, memref<?xf16> into f16
+    %token16 = gpu.destroy_sp_mat async [%token15] %spmat
+    %token17 = gpu.destroy_dn_tensor async [%token16] %dnmat
+    %token18 = gpu.destroy_dn_tensor async [%token17] %dnmat2
+    %token19 = gpu.memcpy async [%token18] %c, %d_c : memref<16x16xf16>, memref<16x16xf16>
+    %token20 = gpu.dealloc async [%token19] %d_c : memref<16x16xf16>
+    %token21 = gpu.dealloc async [%token20] %d_b : memref<32x16xf16>
+    %token22 = gpu.dealloc async [%token21] %d_a : memref<16x32xf16>
+    %token23 = gpu.dealloc async [%token22] %mem3 : memref<?xf16>
+    %token24 = gpu.dealloc async [%token23] %mem2 : memref<?xf16>
+    %token25 = gpu.dealloc async [%token24] %mem1 : memref<?xf16>
+    gpu.wait [%token25]
+    return
   }
 
   //

@@ -54,50 +72,49 @@ module {
     %c64 = arith.constant 64 : index
 
     // Matrices A, B, C (16x32, 32x16, 16x16).
+    %a = memref.alloc() : memref<16x32xf16> // 16x32 with 2:4, row-major
+    %b = memref.alloc() : memref<32x16xf16> // regular dense column-major
+    %c = memref.alloc() : memref<16x16xf16> // accumulator row-major
 
     //
     // Setup matrix A.
     //
-    %DA = tensor.generate {
-    ^bb0(%i: index, %j: index):
-      // (i+ j/2 + 1) if j %2 == 0 else 0
-      %cf0 = arith.constant 0.0 : f16
-      %cf1 = arith.constant 1.0 : f16
-      %j_2 = arith.floordivsi %j, %c2 : index
-      %quotient = arith.remsi %j, %c2 : index
-      %sum = arith.addi %i, %j_2 : index
-      %sum_i = arith.index_cast %sum : index to i64
-      %sum_f = arith.uitofp %sum_i : i64 to f16
-      %sum_f_plus1 = arith.addf %sum_f, %cf1 : f16
-      %is_zero = arith.cmpi "eq", %quotient, %c0 : index
-      %s = arith.select %is_zero, %sum_f_plus1, %cf0 : f16
-      tensor.yield %s : f16
-    } : tensor<16x32xf16>
+    scf.for %ai = %c0 to %c16 step %c1 {
+      scf.for %aj = %c0 to %c16 step %c1 {
+        %cf0 = arith.constant 0.0: f16
+        %a0 = arith.addi %ai, %aj : index
+        %a1 = arith.addi %a0, %c1 : index
+        %a2 = arith.index_cast %a1 : index to i32
+        %a3 = arith.sitofp %a2 : i32 to f16
+        %ajj = arith.muli %aj, %c2 : index
+        %ajj2 = arith.addi %ajj, %c1 : index
+        memref.store %a3, %a[%ai, %ajj] : memref<16x32xf16>
+        memref.store %cf0, %a[%ai, %ajj2] : memref<16x32xf16>
+      }
+    }
 
     //
     // Setup matrix B.
    //
-    %DB = tensor.generate {
-    ^bb0(%i: index, %j: index):
-      // if j_i >=8, j_i - 8 else 0
-      %is_ge8 = arith.cmpi "sge", %j, %c8 : index
-      %j_minus8 = arith.subi %j, %c8 : index
-      %j2 = arith.select %is_ge8, %j_minus8, %j : index
-      %r_i = arith.subi %j2, %i : index
-      %r_i64 = arith.index_cast %r_i : index to i64
-      %r_f = arith.sitofp %r_i64 : i64 to f16
-      tensor.yield %r_f : f16
-    } : tensor<32x16xf16>
+    scf.for %bi = %c0 to %c8 step %c1 {
+      scf.for %bj = %c0 to %c32 step %c1 {
+        %b0 = arith.subi %bi, %bj : index
+        %b1 = arith.index_cast %b0 : index to i32
+        %b2 = arith.sitofp %b1 : i32 to f16
+        %bii = arith.addi %bi, %c8 : index
+        memref.store %b2, %b[%bj, %bi] : memref<32x16xf16>
+        memref.store %b2, %b[%bj, %bii] : memref<32x16xf16>
+      }
+    }
 
     //
     // Reset matrix C.
     //
-    %DC = tensor.generate {
-    ^bb0(%i: index, %j: index):
-      %cf0 = arith.constant 0.0 : f16
-      tensor.yield %cf0 : f16
-    } : tensor<16x16xf16>
-
+    scf.for %ci = %c0 to %c16 step %c1 {
+      scf.for %cj = %c0 to %c16 step %c1 {
+        memref.store %f0, %c[%ci, %cj] : memref<16x16xf16>
+      }
+    }
 
     //
     // Sanity check on 16x32 full 2:4 input matrix A.

@@ -121,7 +138,7 @@ module {
     // CHECK-NEXT: ( 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0, 31, 0 )
     //
     scf.for %pai = %c0 to %c16 step %c1 {
-      %pa0 = vector.transfer_read %DA[%pai, %c0], %f0 : tensor<16x32xf16>, vector<32xf16>
+      %pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x32xf16>, vector<32xf16>
       vector.print %pa0 : vector<32xf16>
     }
 

@@ -163,14 +180,12 @@ module {
     //
     //
     scf.for %pbi = %c0 to %c32 step %c1 {
-      %pb0 = vector.transfer_read %DB[%pbi, %c0], %f0 : tensor<32x16xf16>, vector<16xf16>
+      %pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<32x16xf16>, vector<16xf16>
       vector.print %pb0 : vector<16xf16>
     }
 
     // Call the kernel.
-    %t1 = arith.constant 1 : index
-    %t32 = arith.constant 32 : index
-    %c_out = call @matmul_2to4 (%DA, %DB, %DC): (tensor<16x32xf16>, tensor<32x16xf16>, tensor<16x16xf16>) -> tensor<16x16xf16>
+    call @matmul24(%a, %b, %c): (memref<16x32xf16>, memref<32x16xf16>, memref<16x16xf16>) -> ()
 
     //
     // Verify computed matrix C.

@@ -193,7 +208,7 @@ module {
    // CHECK-NEXT: ( -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688, -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688 )
    //
    scf.for %pci = %c0 to %c16 step %c1 {
-      %pc0 = vector.transfer_read %c_out[%pci, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
+      %pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x16xf16>, vector<16xf16>
      vector.print %pc0 : vector<16xf16>
    }