Skip to content

Commit 619a888

Browse files
authored by (author name not captured in this extraction)
[mlir][sparse][gpu] free all buffers allocated for spGEMM (#66813)
Yup, a bit of an oversight ;-)
1 parent 87b8c85 commit 619a888

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -795,10 +795,10 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
795795
Value rowC = e1.getResult(0);
796796
token = e1.getAsyncToken();
797797
auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
798-
Value colC = e2.getResult(0);
798+
Value colC = e2.getResult(0); // no free needed
799799
token = e2.getAsyncToken();
800800
auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
801-
Value valC = e3.getResult(0);
801+
Value valC = e3.getResult(0); // no free needed
802802
token = e3.getAsyncToken();
803803
Operation *spGenC =
804804
genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
@@ -881,6 +881,17 @@ rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
881881
token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
882882
token = genCopyMemRef(rewriter, loc, colH, colC, token);
883883
token = genCopyMemRef(rewriter, loc, valH, valC, token);
884+
token = genDeallocMemRef(rewriter, loc, rowA, token);
885+
token = genDeallocMemRef(rewriter, loc, colA, token);
886+
token = genDeallocMemRef(rewriter, loc, valA, token);
887+
token = genDeallocMemRef(rewriter, loc, rowB, token);
888+
token = genDeallocMemRef(rewriter, loc, colB, token);
889+
token = genDeallocMemRef(rewriter, loc, valB, token);
890+
token = genDeallocMemRef(rewriter, loc, rowC, token);
891+
token = genDeallocMemRef(rewriter, loc, colC, token);
892+
token = genDeallocMemRef(rewriter, loc, valC, token);
893+
token = genDeallocMemRef(rewriter, loc, buffer1, token);
894+
token = genDeallocMemRef(rewriter, loc, buffer2, token);
884895
tokens.push_back(token);
885896
genBlockingWait(rewriter, loc, tokens);
886897
tokens.clear();

mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
// CHECK-LABEL: func.func @matmulCSR(
77
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<8x8xf32, #{{.*}}>,
8-
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<8x8xf32, #{{.*}}>
8+
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<8x8xf32, #{{.*}}>) -> tensor<8x8xf32, #{{.*}}> {
99
// CHECK: %[[VAL_2:.*]] = arith.constant 8 : index
1010
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
1111
// CHECK: %[[VAL_4:.*]] = arith.constant 9 : index
@@ -72,12 +72,24 @@
7272
// CHECK: %[[VAL_88:.*]] = gpu.memcpy async {{\[}}%[[VAL_87]]] %[[VAL_81]], %[[VAL_49]] : memref<?xindex>, memref<?xindex>
7373
// CHECK: %[[VAL_89:.*]] = gpu.memcpy async {{\[}}%[[VAL_88]]] %[[VAL_82]], %[[VAL_75]] : memref<?xindex>, memref<?xindex>
7474
// CHECK: %[[VAL_90:.*]] = gpu.memcpy async {{\[}}%[[VAL_89]]] %[[VAL_83]], %[[VAL_77]] : memref<?xf32>, memref<?xf32>
75-
// CHECK: gpu.wait {{\[}}%[[VAL_90]]]
76-
// CHECK: %[[VAL_91:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
77-
// CHECK: %[[VAL_92:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
78-
// CHECK: %[[VAL_93:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
79-
// CHECK: %[[VAL_94:.*]] = sparse_tensor.pack %[[VAL_91]], %[[VAL_92]], %[[VAL_93]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
80-
// CHECK: return %[[VAL_94]] : tensor<8x8xf32, #{{.*}}>
75+
// CHECK: %[[VAL_91:.*]] = gpu.dealloc async {{.*}} : memref<?xindex>
76+
// CHECK: %[[VAL_92:.*]] = gpu.dealloc async {{.*}} : memref<?xindex>
77+
// CHECK: %[[VAL_93:.*]] = gpu.dealloc async {{.*}} : memref<?xf32>
78+
// CHECK: %[[VAL_94:.*]] = gpu.dealloc async {{.*}} : memref<?xindex>
79+
// CHECK: %[[VAL_95:.*]] = gpu.dealloc async {{.*}} : memref<?xindex>
80+
// CHECK: %[[VAL_96:.*]] = gpu.dealloc async {{.*}} : memref<?xf32>
81+
// CHECK: %[[VAL_97:.*]] = gpu.dealloc async {{.*}} : memref<?xindex>
82+
// CHECK: %[[VAL_98:.*]] = gpu.dealloc async {{.*}} : memref<?xindex>
83+
// CHECK: %[[VAL_99:.*]] = gpu.dealloc async {{.*}} : memref<?xf32>
84+
// CHECK: %[[VAL_a0:.*]] = gpu.dealloc async {{.*}} : memref<?xi8>
85+
// CHECK: %[[VAL_a1:.*]] = gpu.dealloc async {{.*}} : memref<?xi8>
86+
// CHECK: gpu.wait [%[[VAL_a1]]]
87+
// CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
88+
// CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
89+
// CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
90+
// CHECK: %[[VAL_a5:.*]] = sparse_tensor.pack %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
91+
// CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
92+
// CHECK: }
8193
func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
8294
%B: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
8395
%init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>

0 commit comments

Comments (0)