[mlir][sparse] recognize NVidia 2:4 type for matmul #76758

Merged · 2 commits · Jan 2, 2024
28 changes: 27 additions & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -448,6 +448,23 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
return false;
}

/// Test for 2:4 matrix with suitable metadata.
static bool isAdmissible24(SparseTensorType &aTp) {
return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
}

/// Test for conversion into 2:4 matrix.
static bool isConversionInto24(Value v) {
if (auto cnv = v.getDefiningOp<ConvertOp>()) {
Value a = cnv.getResult();
Value d = cnv.getSource();
SparseTensorType aTp = getSparseTensorType(a);
return isDenseTensor(d) && isAdmissible24(aTp);
}
return false;
}

/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
@@ -925,6 +942,15 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
Value C = op.getOperand(2); // we have C = AB
SmallVector<Value> tokens;

// The cuSPARSELt API currently only allows pruning and compression
// to occur on the device. So we recognize the pattern
// A' = convert A ; dense to 2:4
// C = A'B ; 2:4 matrix mult
// and then perform compression and matrix multiplication on the device.
auto cnv = A.getDefiningOp<ConvertOp>();
assert(cnv);
A = cnv.getSource();

// All inputs should be dense tensors.
if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
return failure();
@@ -1260,7 +1286,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
return rewriteSpGEMM(rewriter, op, enableRT);
if (op->getAttr("DENSE24"))
if (isConversionInto24(op.getOperand(0)))
return rewrite2To4SpMM(rewriter, op);
return rewriteSpMM(rewriter, op, enableRT);
}
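For orientation, the rewriter no longer keys on a DENSE24 attribute: it fires when operand 0 of the matmul is produced by a sparse_tensor.convert whose result carries an admissible 2:4 encoding. A minimal MLIR sketch of the accepted shape follows; the #NV_24 alias is taken from the test updated below, and the @spmm24 name is illustrative only.

#NV_24 = #sparse_tensor.encoding<{
  map = (i, j) -> (i : dense, j floordiv 4 : dense, j mod 4 : block2_4)
}>

func.func @spmm24(%Ad: tensor<?x?xf16>, %B: tensor<?x?xf16>,
                  %Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
  // A' = convert A : dense to 2:4; matched by isConversionInto24().
  %A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
  // C = A'B : 2:4 matmul; handled by rewrite2To4SpMM() on the cuSPARSELt path.
  %C = linalg.matmul
         ins(%A, %B : tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
         outs(%Cin : tensor<?x?xf16>) -> tensor<?x?xf16>
  return %C : tensor<?x?xf16>
}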
2 changes: 1 addition & 1 deletion mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -970,7 +970,7 @@ mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
// Note that this adds a synchronization on the stream.
// TODO: Do we want that?
if (prune_flag == 2) {
int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream);
int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
&cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
int valid = 0;
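Side note on the one-line wrapper change: the call to mgpuMemAlloc now passes a third argument, presumably the host-shared flag, so false requests an ordinary device allocation for the temporary validity flag used by the prune check. This reading of the third parameter is an assumption from the call site and surrounding wrappers, not something stated in this diff.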
@@ -1,5 +1,13 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-gpu-codegen="num-threads=0" | FileCheck %s

#NV_24 = #sparse_tensor.encoding<{
map = ( i, j ) ->
( i : dense,
j floordiv 4 : dense,
j mod 4 : block2_4
)
}>

// CHECK-LABEL: func.func @matmul(
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<?x?xf16>,
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<?x?xf16>,
@@ -51,18 +59,14 @@
// CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
// CHECK: return %[[VAL_55]] : tensor<?x?xf16>
// CHECK: }

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
%0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%1 = arith.mulf %in, %in_0 : f16
%2 = arith.addf %out, %1 : f16
linalg.yield %2 : f16
} -> tensor<?x?xf16>
return %0 : tensor<?x?xf16>
func.func @matmul(%Ad: tensor<?x?xf16>,
%B: tensor<?x?xf16>,
%Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
%A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
%C = linalg.matmul
ins(%A, %B: tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
outs(%Cin: tensor<?x?xf16>) -> tensor<?x?xf16>
return %C : tensor<?x?xf16>
}
}
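A brief note on the #NV_24 encoding used above: the map splits the dense column index into a block index and an in-block position, so dims (i, j) become levels (i, j floordiv 4, j mod 4), and the innermost level is declared block2_4, meaning at most two of the four positions in each block may be nonzero. For example, element (3, 13) lands at level coordinates (3, 3, 1); a row whose nonzeros sit only in even columns, as the integration test below constructs for A, has exactly two nonzeros per block of four and is therefore admissible.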
@@ -1,40 +1,58 @@
// NOTE: this test requires gpu-sm80 and cusparselt
//
// DEFINE: %{compile} = mlir-opt %s \
// DEFINE: --sparsifier="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
// DEFINE: %{compile} = mlir-opt --convert-vector-to-scf --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
// DEFINE: --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
// DEFINE: %s
// DEFINE: %{run} = mlir-cpu-runner \
// DEFINE: --shared-libs=%mlir_cuda_runtime \
// DEFINE: --shared-libs=%mlir_c_runner_utils \
// DEFINE: --e main --entry-point-result=void \
// DEFINE: | FileCheck %s
//
// with RT lib:
//
// RUN: %{compile} enable-runtime-library=true" | %{run}
//
// without RT lib:
//
// RUN: %{compile} enable-runtime-library=false" | %{run}

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
// RUN: %{compile} | %{run}

module {
llvm.func @mgpuCreateSparseLtEnv()
llvm.func @mgpuDestroySparseLtEnv()

//
// TODO: This uses our temporary ATTRIBUTE, replace with 2:4 type!
//
func.func @matmul_2to4(%arg0: tensor<16x32xf16>, %arg1: tensor<32x16xf16>, %arg2: tensor<16x16xf16>) -> tensor<16x16xf16> {
%0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x32xf16>, tensor<32x16xf16>) outs(%arg2 : tensor<16x16xf16>) {
^bb0(%in: f16, %in_0: f16, %out: f16):
%1 = arith.mulf %in, %in_0 : f16
%2 = arith.addf %out, %1 : f16
linalg.yield %2 : f16
} -> tensor<16x16xf16>
return %0 : tensor<16x16xf16>
// cuSPARSELt version of the matmul, coded by hand.
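// Outline: copy A, B, and C to the device, build a 2:4 sparse descriptor
// for A (with PRUNE_AND_CHECK) and dense descriptors for B and C, query
// the three workspace buffer sizes, run the SpMM, copy the result back,
// and release all device buffers and descriptors.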
func.func @matmul24(%a : memref<16x32xf16>,
%b : memref<32x16xf16>,
%c : memref<16x16xf16>) {
%c0 = arith.constant 0.0 : f16
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%c8 = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c1048576 = arith.constant 1048576 : index
%token0 = gpu.wait async
%d_a, %token1 = gpu.alloc async [%token0] () : memref<16x32xf16>
%d_b, %token2 = gpu.alloc async [%token1] () : memref<32x16xf16>
%d_c, %token3 = gpu.alloc async [%token2] () : memref<16x16xf16>
%token4 = gpu.memcpy async [%token3] %d_a, %a : memref<16x32xf16>, memref<16x32xf16>
%token5 = gpu.memcpy async [%token4] %d_b, %b : memref<32x16xf16>, memref<32x16xf16>
%token6 = gpu.memcpy async [%token5] %d_c, %c : memref<16x16xf16>, memref<16x16xf16>
%spmat, %token8 = gpu.create_2to4_spmat async [%token6]{PRUNE_AND_CHECK} %c16, %c32, %d_a: memref<16x32xf16>
%dnmat, %token9 = gpu.create_dn_tensor async [%token8] %d_b, %c32, %c16: index, index into memref<32x16xf16>
%dnmat2, %token10 = gpu.create_dn_tensor async [%token9] %d_c, %c16, %c16: index, index into memref<16x16xf16>
%bufferSz0, %bufferSz1, %bufferSz2, %token11 = gpu.spmm_buffer_size async [%token10] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index,index into f16
%mem1, %token12 = gpu.alloc async [%token11] (%bufferSz0) : memref<?xf16>
%mem2, %token13 = gpu.alloc async [%token12] (%bufferSz1) : memref<?xf16>
%mem3, %token14 = gpu.alloc async [%token13] (%bufferSz2) : memref<?xf16>
%token15 = gpu.spmm async [%token14] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>, memref<?xf16>,memref<?xf16> into f16
%token16 = gpu.destroy_sp_mat async [%token15] %spmat
%token17 = gpu.destroy_dn_tensor async [%token16] %dnmat
%token18 = gpu.destroy_dn_tensor async [%token17] %dnmat2
%token19 = gpu.memcpy async [%token18] %c, %d_c : memref<16x16xf16>, memref<16x16xf16>
%token20 = gpu.dealloc async [%token19] %d_c : memref<16x16xf16>
%token21 = gpu.dealloc async [%token20] %d_b : memref<32x16xf16>
%token22 = gpu.dealloc async [%token21] %d_a : memref<16x32xf16>
%token23 = gpu.dealloc async [%token22] %mem3 : memref<?xf16>
%token24 = gpu.dealloc async [%token23] %mem2 : memref<?xf16>
%token25 = gpu.dealloc async [%token24] %mem1 : memref<?xf16>
gpu.wait [%token25]
return
}

//
@@ -54,50 +72,49 @@ module {
%c64 = arith.constant 64 : index

// Matrices A, B, C (16x32, 32x16, 16x16).
%a = memref.alloc() : memref<16x32xf16> // 16x32 with 2:4, row-major
%b = memref.alloc() : memref<32x16xf16> // regular dense column-major
%c = memref.alloc() : memref<16x16xf16> // accumulator row-major

//
// Setup matrix A.
//
%DA = tensor.generate {
^bb0(%i: index, %j: index):
// (i+ j/2 + 1) if j %2 == 0 else 0
%cf0 = arith.constant 0.0 : f16
%cf1 = arith.constant 1.0 : f16
%j_2 = arith.floordivsi %j, %c2 : index
%quotient = arith.remsi %j, %c2 : index
%sum = arith.addi %i, %j_2 : index
%sum_i = arith.index_cast %sum : index to i64
%sum_f = arith.uitofp %sum_i : i64 to f16
%sum_f_plus1 = arith.addf %sum_f, %cf1 : f16
%is_zero = arith.cmpi "eq", %quotient, %c0 : index
%s = arith.select %is_zero, %sum_f_plus1, %cf0 : f16
tensor.yield %s : f16
} : tensor<16x32xf16>
scf.for %ai = %c0 to %c16 step %c1 {
scf.for %aj = %c0 to %c16 step %c1 {
%cf0 = arith.constant 0.0: f16
%a0 = arith.addi %ai, %aj : index
%a1 = arith.addi %a0, %c1 : index
%a2 = arith.index_cast %a1 : index to i32
%a3 = arith.sitofp %a2 : i32 to f16
%ajj = arith.muli %aj, %c2 : index
%ajj2 = arith.addi %ajj, %c1 : index
memref.store %a3, %a[%ai, %ajj] : memref<16x32xf16>
memref.store %cf0, %a[%ai, %ajj2] : memref<16x32xf16>
}
}

//
// Setup matrix B.
//
%DB = tensor.generate {
^bb0(%i: index, %j: index):
// if j_i >=8, j_i - 8 else 0
%is_ge8 = arith.cmpi "sge", %j, %c8 : index
%j_minus8 = arith.subi %j, %c8 : index
%j2 = arith.select %is_ge8, %j_minus8, %j : index
%r_i = arith.subi %j2, %i : index
%r_i64 = arith.index_cast %r_i : index to i64
%r_f = arith.sitofp %r_i64 : i64 to f16
tensor.yield %r_f : f16
} : tensor<32x16xf16>
scf.for %bi = %c0 to %c8 step %c1 {
scf.for %bj = %c0 to %c32 step %c1 {
%b0 = arith.subi %bi, %bj : index
%b1 = arith.index_cast %b0 : index to i32
%b2 = arith.sitofp %b1 : i32 to f16
%bii = arith.addi %bi, %c8 : index
memref.store %b2, %b[%bj, %bi] : memref<32x16xf16>
memref.store %b2, %b[%bj, %bii] : memref<32x16xf16>
}
}

//
// Reset matrix C.
//
%DC = tensor.generate {
^bb0(%i: index, %j: index):
%cf0 = arith.constant 0.0 : f16
tensor.yield %cf0 : f16
} : tensor<16x16xf16>

scf.for %ci = %c0 to %c16 step %c1 {
scf.for %cj = %c0 to %c16 step %c1 {
memref.store %f0, %c[%ci, %cj] : memref<16x16xf16>
}
}

//
// Sanity check on 16x32 full 2:4 input matrix A.
@@ -121,7 +138,7 @@
// CHECK-NEXT: ( 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0, 31, 0 )
//
scf.for %pai = %c0 to %c16 step %c1 {
%pa0 = vector.transfer_read %DA[%pai, %c0], %f0 : tensor<16x32xf16>, vector<32xf16>
%pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x32xf16>, vector<32xf16>
vector.print %pa0 : vector<32xf16>
}

@@ -163,14 +180,12 @@ module {
//
//
scf.for %pbi = %c0 to %c32 step %c1 {
%pb0 = vector.transfer_read %DB[%pbi, %c0], %f0 : tensor<32x16xf16>, vector<16xf16>
%pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<32x16xf16>, vector<16xf16>
vector.print %pb0 : vector<16xf16>
}

// Call the kernel.
%t1 = arith.constant 1 : index
%t32 = arith.constant 32 : index
%c_out = call @matmul_2to4 (%DA, %DB, %DC): (tensor<16x32xf16>, tensor<32x16xf16>, tensor<16x16xf16>) -> tensor<16x16xf16>
call @matmul24(%a, %b, %c): (memref<16x32xf16>, memref<32x16xf16>, memref<16x16xf16>) -> ()

//
// Verify computed matrix C.
@@ -193,7 +208,7 @@ module {
// CHECK-NEXT: ( -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688, -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688 )
//
scf.for %pci = %c0 to %c16 step %c1 {
%pc0 = vector.transfer_read %c_out[%pci, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
%pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x16xf16>, vector<16xf16>
vector.print %pc0 : vector<16xf16>
}
