Commit e37fc3c
Author: K-Wu
[mlir][sparse][gpu] Impl 2:4 SpMM rewrite for linalg op w/ DENSE24 attr
Differential Revision: https://reviews.llvm.org/D154772
Parent: 7f9ba19
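Background for the change: in 2:4 structured sparsity, every contiguous group of four elements holds at most two nonzeros, a pattern accelerated by NVIDIA sparse tensor cores (exposed through libraries such as cuSPARSELt). This commit teaches the sparsifier's GPU libgen path to recognize a dense linalg matmul carrying a DENSE24 attribute and lower it to the gpu dialect's sparse-library ops (gpu.create_2to4_spmat, gpu.spmm_buffer_size, gpu.spmm) instead of the generic SpMM path.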

2 files changed, 180 insertions(+), 0 deletions(-)

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 111 additions & 0 deletions
@@ -374,6 +374,11 @@ static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
   return false;
 }
 
+/// Determines if the given value is a dense tensor instead of a sparse one.
+static bool isDenseTensor(Value v) {
+  return (sparse_tensor::getSparseTensorType(v).isAllDense());
+}
+
 /// Test for sorted COO with suitable data and coordinates types.
 static bool isAdmissibleCOO(SparseTensorType &aTp) {
   return aTp.isCompressedLvl(0) && aTp.isOrderedLvl(0) && !aTp.isUniqueLvl(0) &&
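A note on the helper: sparse_tensor::getSparseTensorType wraps any ranked tensor type and treats a missing sparse encoding as all-dense, so the predicate is true exactly for ordinary dense tensors. A minimal sketch under that same API, with `v` an arbitrary tensor value:

  // Sketch: the wrapper answers per-level queries, so the dense test is
  // simply "every level of the (possibly implicit) encoding is dense".
  SparseTensorType stt = sparse_tensor::getSparseTensorType(v);
  bool dense = stt.isAllDense(); // true when v carries no sparse levels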
@@ -656,6 +661,109 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
   return success();
 }
 
+// Match and rewrite 2:4 SpMM kernels.
+static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
+                                     linalg::GenericOp op) {
+  Location loc = op.getLoc();
+  Value A = op.getOperand(0);
+  Value B = op.getOperand(1);
+  Value C = op.getOperand(2); // we have C = AB
+  SmallVector<Value> tokens;
+
+  // All inputs should be dense tensors.
+  if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
+    return failure();
+
+  Value bufA = genTensorToMemref(rewriter, loc, A);
+  Value matA = genAllocCopy(rewriter, loc, bufA, tokens);
+  Value bufB = genTensorToMemref(rewriter, loc, B);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value bufC = genTensorToMemref(rewriter, loc, C);
+  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
+  Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
+  Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
+
+  Type indexTp = rewriter.getIndexType();
+  Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
+  Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
+      loc, spMatHandleTp, tokenTp, token, szm, szk, matA);
+
+  Value spMatA = spGenA->getResult(0);
+  token = spGenA->getResult(1);
+  auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
+      loc, dnTensorHandleTp, tokenTp, token, matB,
+      SmallVector<Value>{szk, szn});
+  Value dnB = dmatB.getResult(0);
+  token = dmatB.getAsyncToken();
+  auto dmatC = rewriter.create<gpu::CreateDnTensorOp>(
+      loc, dnTensorHandleTp, tokenTp, token, matC,
+      SmallVector<Value>{szm, szn});
+  Value dnC = dmatC.getResult(0);
+  token = dmatC.getAsyncToken();
+
+  auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
+
+  // Precompute buffersize for SpMM.
+  SmallVector<Type> bufferTypes_{indexTp, indexTp, indexTp};
+  TypeRange bufferTypes(bufferTypes_);
+  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
+      loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
+      /*computeType=*/dmatCType);
+
+  token = bufferComp.getAsyncToken();
+  Value bufferSz = bufferComp.getResult(0);
+  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
+  Value buffer = buf.getResult(0);
+  token = buf.getAsyncToken();
+
+  Value bufferSz2 = bufferComp.getResult(1);
+  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
+  Value buffer2 = buf2.getResult(0);
+  token = buf2.getAsyncToken();
+
+  Value bufferSz3 = bufferComp.getResult(2);
+  auto buf3 = genAllocBuffer(rewriter, loc, bufferSz3, token);
+  Value buffer3 = buf3.getResult(0);
+  token = buf3.getAsyncToken();
+
+  auto dnCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
+
+  // Perform the SpMM.
+  auto spmmComp = rewriter.create<gpu::SpMMOp>(
+      loc, tokenTp, token, spMatA, dnB, dnC, /*computeType=*/dnCType,
+      SmallVector<Value>{buffer, buffer2, buffer3});
+  token = spmmComp.getAsyncToken();
+
+  // Copy data back to host and free all the resources.
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnB)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
+              .getAsyncToken();
+
+  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, buffer2, token);
+  token = genDeallocMemRef(rewriter, loc, buffer3, token);
+  token = genDeallocMemRef(rewriter, loc, matA, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
+  token = genDeallocMemRef(rewriter, loc, matC, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
+  return success();
+}
+
 /// Match and rewrite SDDMM kernel.
 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
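The body above follows one idiom throughout: device work is ordered by threading a single gpu async token through every op, so the whole kernel runs on one logical stream and the host blocks only at the two genBlockingWait calls. A condensed sketch of that chaining, reusing names and helpers that appear in this file:

  // Condensed sketch of the async-token chaining used in rewrite2To4SpMM.
  // Each gpu op waits on the previous token and produces the next one.
  Value token = genFirstWait(rewriter, loc);            // gpu.wait async
  auto create = rewriter.create<gpu::CreateDnTensorOp>(
      loc, dnTensorHandleTp, tokenTp, token, matB,
      SmallVector<Value>{szk, szn});
  token = create.getAsyncToken();                       // next op waits on this
  token = genDeallocMemRef(rewriter, loc, matB, token); // still stream-ordered
  tokens.push_back(token);
  genBlockingWait(rewriter, loc, tokens);               // host sync point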
@@ -906,6 +1014,9 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         // TODO: add transposed {i, k}, {k, j}
         // TODO: maybe add transposed {i, j} in future
         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
+      if (op->getAttr("DENSE24"))
+        return rewrite2To4SpMM(rewriter, op);
+
       return rewriteSpMM(rewriter, op, enableRT);
     }
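The dispatch keys on mere presence of the attribute, so any attribute value selects the 2:4 path; tagging with a unit attribute is the conventional choice. A hypothetical producer-side sketch, not part of this commit, where `matmulOp` and `builder` are assumed to exist:

  // Hypothetical: mark a linalg.generic matmul so LinalgOpRewriter routes
  // it through rewrite2To4SpMM instead of the generic rewriteSpMM.
  matmulOp->setAttr("DENSE24", builder.getUnitAttr());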

New test file: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:   --sparsification="enable-gpu-libgen" | FileCheck %s
+
+// CHECK-LABEL: func.func @matmul(
+// CHECK-SAME: %[[VAL_0:.*0]]: tensor<?x?xf16>,
+// CHECK-SAME: %[[VAL_1:.*1]]: tensor<?x?xf16>,
+// CHECK-SAME: %[[VAL_2:.*2]]: tensor<?x?xf16>) -> tensor<?x?xf16> {
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = bufferization.to_memref %[[VAL_0]] : memref<?x?xf16>
+// CHECK: %[[VAL_6:.*]] = gpu.wait async
+// CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_5]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK: %[[VAL_8:.*]] = memref.dim %[[VAL_5]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]] = gpu.alloc async {{\[}}%[[VAL_6]]] (%[[VAL_7]], %[[VAL_8]]) : memref<?x?xf16>
+// CHECK: %[[VAL_11:.*]] = gpu.memcpy async {{\[}}%[[VAL_10]]] %[[VAL_9]], %[[VAL_5]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf16>
+// CHECK: %[[VAL_13:.*]] = gpu.wait async
+// CHECK: %[[VAL_14:.*]] = memref.dim %[[VAL_12]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK: %[[VAL_15:.*]] = memref.dim %[[VAL_12]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK: %[[VAL_16:.*]], %[[VAL_17:.*]] = gpu.alloc async {{\[}}%[[VAL_13]]] (%[[VAL_14]], %[[VAL_15]]) : memref<?x?xf16>
+// CHECK: %[[VAL_18:.*]] = gpu.memcpy async {{\[}}%[[VAL_17]]] %[[VAL_16]], %[[VAL_12]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK: %[[VAL_19:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf16>
+// CHECK: %[[VAL_20:.*]] = gpu.wait async
+// CHECK: %[[VAL_21:.*]] = memref.dim %[[VAL_19]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK: %[[VAL_22:.*]] = memref.dim %[[VAL_19]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK: %[[VAL_23:.*]], %[[VAL_24:.*]] = gpu.alloc async {{\[}}%[[VAL_20]]] (%[[VAL_21]], %[[VAL_22]]) : memref<?x?xf16>
+// CHECK: %[[VAL_25:.*]] = gpu.memcpy async {{\[}}%[[VAL_24]]] %[[VAL_23]], %[[VAL_19]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK: gpu.wait {{\[}}%[[VAL_11]], %[[VAL_18]], %[[VAL_25]]]
+// CHECK: %[[VAL_26:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK: %[[VAL_27:.*]] = memref.dim %[[VAL_16]], %[[VAL_3]] : memref<?x?xf16>
+// CHECK: %[[VAL_28:.*]] = memref.dim %[[VAL_23]], %[[VAL_4]] : memref<?x?xf16>
+// CHECK: %[[VAL_29:.*]] = gpu.wait async
+// CHECK: %[[VAL_30:.*]], %[[VAL_31:.*]] = gpu.create_2to4_spmat async {{\[}}%[[VAL_29]]] %[[VAL_26]], %[[VAL_27]], %[[VAL_9]] : memref<?x?xf16>
+// CHECK: %[[VAL_32:.*]], %[[VAL_33:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_31]]] %[[VAL_16]], %[[VAL_27]], %[[VAL_28]] : index, index into memref<?x?xf16>
+// CHECK: %[[VAL_34:.*]], %[[VAL_35:.*]] = gpu.create_dn_tensor async {{\[}}%[[VAL_33]]] %[[VAL_23]], %[[VAL_26]], %[[VAL_28]] : index, index into memref<?x?xf16>
+// CHECK: %[[VAL_36:.*]]:3, %[[VAL_37:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_35]]] %[[VAL_30]], %[[VAL_32]], %[[VAL_34]] : index, index, index into f16
+// CHECK: %[[VAL_38:.*]], %[[VAL_39:.*]] = gpu.alloc async {{\[}}%[[VAL_37]]] (%[[VAL_36]]#0) : memref<?xi8>
+// CHECK: %[[VAL_40:.*]], %[[VAL_41:.*]] = gpu.alloc async {{\[}}%[[VAL_39]]] (%[[VAL_36]]#1) : memref<?xi8>
+// CHECK: %[[VAL_42:.*]], %[[VAL_43:.*]] = gpu.alloc async {{\[}}%[[VAL_41]]] (%[[VAL_36]]#2) : memref<?xi8>
+// CHECK: %[[VAL_44:.*]] = gpu.spmm async {{\[}}%[[VAL_43]]] %[[VAL_30]], %[[VAL_32]], %[[VAL_34]], %[[VAL_38]], %[[VAL_40]], %[[VAL_42]] : memref<?xi8>, memref<?xi8>, memref<?xi8> into f16
+// CHECK: %[[VAL_45:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_44]]] %[[VAL_30]]
+// CHECK: %[[VAL_46:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_45]]] %[[VAL_32]]
+// CHECK: %[[VAL_47:.*]] = gpu.destroy_dn_tensor async {{\[}}%[[VAL_46]]] %[[VAL_34]]
+// CHECK: %[[VAL_48:.*]] = gpu.dealloc async {{\[}}%[[VAL_47]]] %[[VAL_38]] : memref<?xi8>
+// CHECK: %[[VAL_49:.*]] = gpu.dealloc async {{\[}}%[[VAL_48]]] %[[VAL_40]] : memref<?xi8>
+// CHECK: %[[VAL_50:.*]] = gpu.dealloc async {{\[}}%[[VAL_49]]] %[[VAL_42]] : memref<?xi8>
+// CHECK: %[[VAL_51:.*]] = gpu.dealloc async {{\[}}%[[VAL_50]]] %[[VAL_9]] : memref<?x?xf16>
+// CHECK: %[[VAL_52:.*]] = gpu.dealloc async {{\[}}%[[VAL_51]]] %[[VAL_16]] : memref<?x?xf16>
+// CHECK: %[[VAL_53:.*]] = gpu.memcpy async {{\[}}%[[VAL_52]]] %[[VAL_19]], %[[VAL_23]] : memref<?x?xf16>, memref<?x?xf16>
+// CHECK: %[[VAL_54:.*]] = gpu.dealloc async {{\[}}%[[VAL_53]]] %[[VAL_23]] : memref<?x?xf16>
+// CHECK: gpu.wait {{\[}}%[[VAL_54]]]
+// CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
+// CHECK: return %[[VAL_55]] : tensor<?x?xf16>
+// CHECK: }
+
+#map = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+module {
+  func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
+    %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
+    ^bb0(%in: f16, %in_0: f16, %out: f16):
+      %1 = arith.mulf %in, %in_0 : f16
+      %2 = arith.addf %out, %1 : f16
+      linalg.yield %2 : f16
+    } -> tensor<?x?xf16>
+    return %0 : tensor<?x?xf16>
+  }
+}
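Note on the test input: inside the attribute dictionary, the bare name in { DENSE24, indexing_maps = ... } is MLIR shorthand for DENSE24 = unit, which is exactly the presence that the op->getAttr("DENSE24") check above detects. The RUN line first generalizes named linalg ops, then runs the sparsifier with enable-gpu-libgen so the pattern fires, and FileCheck verifies the full async chain from gpu.create_2to4_spmat through gpu.spmm to the final copy-back and deallocation.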
