Skip to content

Commit ff6c84b

Browse files
committed
[mlir][sparse] generalize sparse storage format to many more types
Rationale: Narrower types for overhead storage yield a smaller memory footprint for sparse tensors and thus need to be supported. Also, more value types need to be supported to deal with all kinds of kernels. Since the "one-size-fits-all" sparse storage scheme implementation is used instead of actual codegen, the library needs to be able to support all combinations of desired types. With some crafty templating and overloading, the actual code for this is kept reasonably sized though. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D96819
1 parent 766ee10 commit ff6c84b

File tree

6 files changed

+376
-92
lines changed

6 files changed

+376
-92
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// RUN: mlir-opt %s \
2+
// RUN: --test-sparsification="lower ptr-type=2 ind-type=2 fast-output" \
3+
// RUN: --convert-linalg-to-loops \
4+
// RUN: --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
5+
// RUN: --std-bufferize --finalizing-bufferize \
6+
// RUN: --convert-scf-to-std --convert-vector-to-llvm --convert-std-to-llvm | \
7+
// RUN: TENSOR0="%mlir_integration_test_dir/data/test.mtx" \
8+
// RUN: mlir-cpu-runner \
9+
// RUN: -e entry -entry-point-result=void \
10+
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
11+
// RUN: FileCheck %s
12+
13+
//
14+
// Use descriptive names for opaque pointers.
15+
//
16+
!Filename = type !llvm.ptr<i8>
17+
!SparseTensor = type !llvm.ptr<i8>
18+
19+
#trait_sampled_dense_dense = {
20+
indexing_maps = [
21+
affine_map<(i,j,k) -> (i,j)>, // S
22+
affine_map<(i,j,k) -> (i,k)>, // A
23+
affine_map<(i,j,k) -> (k,j)>, // B
24+
affine_map<(i,j,k) -> (i,j)> // X (out)
25+
],
26+
sparse = [
27+
[ "S", "S" ], // S
28+
[ "D", "D" ], // A
29+
[ "D", "D" ], // B
30+
[ "D", "D" ] // X
31+
],
32+
iterator_types = ["parallel", "parallel", "reduction"],
33+
doc = "X(i,j) += S(i,j) SUM_k A(i,k) B(k,j)"
34+
}
35+
36+
//
37+
// Integration test that lowers a kernel annotated as sparse to
38+
// actual sparse code, initializes a matching sparse storage scheme
39+
// from file, and runs the resulting code with the JIT compiler.
40+
//
41+
module {
42+
//
43+
// The kernel expressed as an annotated Linalg op. The kernel
44+
// computes a sampled matrix matrix multiplication.
45+
//
46+
func @sampled_dense_dense(%argS: !SparseTensor,
47+
%arga: tensor<?x?xf32>,
48+
%argb: tensor<?x?xf32>,
49+
%argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
50+
%args = linalg.sparse_tensor %argS : !SparseTensor to tensor<?x?xf32>
51+
%0 = linalg.generic #trait_sampled_dense_dense
52+
ins(%args, %arga, %argb: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
53+
outs(%argx: tensor<?x?xf32>) {
54+
^bb(%s: f32, %a: f32, %b: f32, %x: f32):
55+
%0 = mulf %a, %b : f32
56+
%1 = mulf %s, %0 : f32
57+
%2 = addf %x, %1 : f32
58+
linalg.yield %2 : f32
59+
} -> tensor<?x?xf32>
60+
return %0 : tensor<?x?xf32>
61+
}
62+
63+
//
64+
// Runtime support library that is called directly from here.
65+
//
66+
func private @getTensorFilename(index) -> (!Filename)
67+
func private @newSparseTensor(!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
68+
func private @delSparseTensor(!SparseTensor) -> ()
69+
func private @print_memref_f32(%ptr : tensor<*xf32>)
70+
71+
//
72+
// Main driver that reads matrix from file and calls the sparse kernel.
73+
//
74+
func @entry() {
75+
%d0 = constant 0.0 : f32
76+
%c0 = constant 0 : index
77+
%c1 = constant 1 : index
78+
%c2 = constant 2 : index
79+
%c5 = constant 5 : index
80+
%c10 = constant 10 : index
81+
82+
// Mark both dimensions of the matrix as sparse and encode the
83+
// storage scheme types (this must match the metadata in the
84+
// trait and compiler switches).
85+
%annotations = alloc(%c2) : memref<?xi1>
86+
%sparse = constant true
87+
store %sparse, %annotations[%c0] : memref<?xi1>
88+
store %sparse, %annotations[%c1] : memref<?xi1>
89+
%i32 = constant 3 : index
90+
%f32 = constant 1 : index
91+
92+
// Setup memory for the dense matrices and initialize.
93+
%adata = alloc(%c5, %c10) : memref<?x?xf32>
94+
%bdata = alloc(%c10, %c5) : memref<?x?xf32>
95+
%xdata = alloc(%c5, %c5) : memref<?x?xf32>
96+
scf.for %i = %c0 to %c5 step %c1 {
97+
scf.for %j = %c0 to %c5 step %c1 {
98+
store %d0, %xdata[%i, %j] : memref<?x?xf32>
99+
}
100+
%p = addi %i, %c1 : index
101+
%q = index_cast %p : index to i32
102+
%d = sitofp %q : i32 to f32
103+
scf.for %j = %c0 to %c10 step %c1 {
104+
store %d, %adata[%i, %j] : memref<?x?xf32>
105+
store %d, %bdata[%j, %i] : memref<?x?xf32>
106+
}
107+
}
108+
%a = tensor_load %adata : memref<?x?xf32>
109+
%b = tensor_load %bdata : memref<?x?xf32>
110+
%x = tensor_load %xdata : memref<?x?xf32>
111+
112+
// Read the sparse matrix from file, construct sparse storage
113+
// according to <sparse,sparse> in memory, and call the kernel.
114+
%fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
115+
%s = call @newSparseTensor(%fileName, %annotations, %i32, %i32, %f32)
116+
: (!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
117+
%0 = call @sampled_dense_dense(%s, %a, %b, %x)
118+
: (!SparseTensor, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
119+
120+
// Print the result for verification.
121+
//
122+
// CHECK: ( 10, 0, 0, 56, 0 )
123+
// CHECK: ( 0, 80, 0, 0, 250 )
124+
// CHECK: ( 0, 0, 270, 0, 0 )
125+
// CHECK: ( 164, 0, 0, 640, 0 )
126+
// CHECK: ( 0, 520, 0, 0, 1250 )
127+
//
128+
%r = tensor_to_memref %0 : memref<?x?xf32>
129+
scf.for %i = %c0 to %c5 step %c1 {
130+
%v = vector.transfer_read %r[%i, %c0], %d0: memref<?x?xf32>, vector<5xf32>
131+
vector.print %v : vector<5xf32>
132+
}
133+
134+
// Release the resources.
135+
call @delSparseTensor(%s) : (!SparseTensor) -> ()
136+
dealloc %adata : memref<?x?xf32>
137+
dealloc %bdata : memref<?x?xf32>
138+
dealloc %xdata : memref<?x?xf32>
139+
140+
return
141+
}
142+
}

mlir/integration_test/Sparse/CPU/sparse_sum.mlir

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ module {
5555
// Runtime support library that is called directly from here.
5656
//
5757
func private @getTensorFilename(index) -> (!Filename)
58-
func private @newSparseTensor(!Filename, memref<?xi1>) -> (!SparseTensor)
58+
func private @newSparseTensor(!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
5959
func private @delSparseTensor(!SparseTensor) -> ()
6060
func private @print_memref_f64(%ptr : tensor<*xf64>)
6161

@@ -68,12 +68,15 @@ module {
6868
%c1 = constant 1 : index
6969
%c2 = constant 2 : index
7070

71-
// Mark both dimensions of the matrix as sparse
72-
// (this must match the annotation in the trait).
71+
// Mark both dimensions of the matrix as sparse and encode the
72+
// storage scheme types (this must match the metadata in the
73+
// trait and compiler switches).
7374
%annotations = alloc(%c2) : memref<?xi1>
7475
%sparse = constant true
7576
store %sparse, %annotations[%c0] : memref<?xi1>
7677
store %sparse, %annotations[%c1] : memref<?xi1>
78+
%i64 = constant 2 : index
79+
%f64 = constant 0 : index
7780

7881
// Setup memory for a single reduction scalar,
7982
// initialized to zero.
@@ -84,8 +87,8 @@ module {
8487
// Read the sparse matrix from file, construct sparse storage
8588
// according to <sparse,sparse> in memory, and call the kernel.
8689
%fileName = call @getTensorFilename(%c0) : (index) -> (!Filename)
87-
%a = call @newSparseTensor(%fileName, %annotations)
88-
: (!Filename, memref<?xi1>) -> (!SparseTensor)
90+
%a = call @newSparseTensor(%fileName, %annotations, %i64, %i64, %f64)
91+
: (!Filename, memref<?xi1>, index, index, index) -> (!SparseTensor)
8992
%0 = call @kernel_sum_reduce(%a, %x)
9093
: (!SparseTensor, tensor<f64>) -> tensor<f64>
9194

mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ class TensorToPointersConverter
7373
Type eltType = resType.cast<ShapedType>().getElementType();
7474
StringRef name;
7575
if (eltType.isIndex() || eltType.isInteger(64))
76-
name = "sparsePtrsI64";
76+
name = "sparsePointers64";
77+
else if (eltType.isInteger(32))
78+
name = "sparsePointers32";
7779
else
7880
return failure();
7981
rewriter.replaceOpWithNewOp<CallOp>(
@@ -95,7 +97,9 @@ class TensorToIndicesConverter
9597
Type eltType = resType.cast<ShapedType>().getElementType();
9698
StringRef name;
9799
if (eltType.isIndex() || eltType.isInteger(64))
98-
name = "sparseIndxsI64";
100+
name = "sparseIndices64";
101+
else if (eltType.isInteger(32))
102+
name = "sparseIndices32";
99103
else
100104
return failure();
101105
rewriter.replaceOpWithNewOp<CallOp>(
@@ -117,7 +121,9 @@ class TensorToValuesConverter
117121
Type eltType = resType.cast<ShapedType>().getElementType();
118122
StringRef name;
119123
if (eltType.isF64())
120-
name = "sparseValsF64";
124+
name = "sparseValuesF64";
125+
else if (eltType.isF32())
126+
name = "sparseValuesF32";
121127
else
122128
return failure();
123129
rewriter.replaceOpWithNewOp<CallOp>(

0 commit comments

Comments
 (0)