Commit 2400f70

[mlir][sparse] add assemble test for Batched-CSR and CSR-Dense (#81660)
These are formats supported by PyTorch sparse, so it is good to make sure that our assemble instructions work on them.
1 parent f0b271e commit 2400f70
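
As a rough illustration (not part of this commit), the values/positions/coordinates triplet consumed by sparse_tensor.assemble plays the same role as the values/crow_indices/col_indices triplet of a PyTorch CSR tensor. A minimal sketch, assuming the standard torch.sparse_csr_tensor API and using only the first 3x2 batch of the test data below:

    import torch

    # Hypothetical example: first 3x2 batch of the BatchedCSR test data.
    crow_indices = torch.tensor([0, 2, 3, 4])          # row positions (nrows + 1 entries)
    col_indices  = torch.tensor([0, 1, 1, 0])          # column coordinates
    values       = torch.tensor([1.0, 2.0, 3.0, 4.0])  # stored values
    a = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(3, 2))
    print(a.to_dense())  # [[1., 2.], [0., 3.], [4., 0.]]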

1 file changed, +95 -0 lines changed

@@ -0,0 +1,95 @@
//--------------------------------------------------------------------------------------------------
// WHEN CREATING A NEW TEST, PLEASE JUST COPY & PASTE WITHOUT EDITS.
//
// Set-up that's shared across all tests in this directory. In principle, this
// config could be moved to lit.local.cfg. However, there are downstream users that
// do not use these LIT config files. Hence why this is kept inline.
//
// DEFINE: %{sparsifier_opts} = enable-runtime-library=true
// DEFINE: %{sparsifier_opts_sve} = enable-arm-sve=true %{sparsifier_opts}
// DEFINE: %{compile} = mlir-opt %s --sparsifier="%{sparsifier_opts}"
// DEFINE: %{compile_sve} = mlir-opt %s --sparsifier="%{sparsifier_opts_sve}"
// DEFINE: %{run_libs} = -shared-libs=%mlir_c_runner_utils,%mlir_runner_utils
// DEFINE: %{run_opts} = -e entry -entry-point-result=void
// DEFINE: %{run} = mlir-cpu-runner %{run_opts} %{run_libs}
// DEFINE: %{run_sve} = %mcr_aarch64_cmd --march=aarch64 --mattr="+sve" %{run_opts} %{run_libs}
//
// DEFINE: %{env} =
//--------------------------------------------------------------------------------------------------

// RUN: %{compile} | %{run} | FileCheck %s
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false
// RUN: %{compile} | %{run} | FileCheck %s

#BatchedCSR = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed),
  posWidth = 64,
  crdWidth = 32
}>
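// Batched CSR: the batch (d0) and row (d1) levels are dense and only the
// innermost level (d2) is compressed; positions use i64, coordinates i32.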

#CSRDense = #sparse_tensor.encoding<{
  map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense),
  posWidth = 64,
  crdWidth = 32
}>
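// CSR-dense: the middle level (d1) is compressed, and every stored d1
// coordinate carries a dense run of values along the innermost level (d2);
// positions use i64, coordinates i32.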

// Test with batched-CSR and CSR-dense.
module {
  //
  // Main driver.
  //
  func.func @entry() {
    %c0 = arith.constant 0 : index
    %f0 = arith.constant 0.0 : f32

    //
    // Setup BatchedCSR.
    //
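    // For this (dense, dense, compressed) 4x3x2 tensor, %pos1 holds 4*3+1 = 13
    // positions (one per (batch, row) pair plus a final end marker), and row r of
    // batch b owns entries [pos1[3*b+r], pos1[3*b+r+1]) of %crd1/%data1.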

    %data1 = arith.constant dense<
      [ 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0 ]> : tensor<16xf32>
    %pos1 = arith.constant dense<
      [ 0, 2, 3, 4, 6, 6, 7, 9, 11, 13, 14, 15, 16 ]> : tensor<13xi64>
    %crd1 = arith.constant dense<
      [ 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1 ]> : tensor<16xi32>

    %s1 = sparse_tensor.assemble %data1, %pos1, %crd1 : tensor<16xf32>, tensor<13xi64>, tensor<16xi32> to tensor<4x3x2xf32, #BatchedCSR>

    //
    // Setup CSRDense.
    //
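    // For this (dense, compressed, dense) 4x3x2 tensor, %pos2 holds 4+1 = 5
    // positions (one per batch plus a final end marker), %crd2 lists the 11 stored
    // row coordinates, and %data2 holds 11*2 = 22 values: one dense pair along d2
    // per stored row.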

    %data2 = arith.constant dense<
      [ 1.0, 2.0, 0.0, 3.0, 4.0, 0.0, 5.0, 6.0, 0.0, 7.0, 8.0,
        9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 0.0, 0.0, 15.0, 0.0, 16.0 ]> : tensor<22xf32>
    %pos2 = arith.constant dense<
      [ 0, 3, 5, 8, 11 ]> : tensor<5xi64>
    %crd2 = arith.constant dense<
      [ 0, 1, 2, 0, 2, 0, 1, 2, 0, 1, 2 ]> : tensor<11xi32>

    %s2 = sparse_tensor.assemble %data2, %pos2, %crd2 : tensor<22xf32>, tensor<5xi64>, tensor<11xi32> to tensor<4x3x2xf32, #CSRDense>

    //
    // Verify.
    //
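    // Both assembled tensors represent the same dense 4x3x2 contents, so the two
    // printed vectors must match the same pattern.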
    // CHECK: ( ( ( 1, 2 ), ( 0, 3 ), ( 4, 0 ) ), ( ( 5, 6 ), ( 0, 0 ), ( 0, 7 ) ), ( ( 8, 9 ), ( 10, 11 ), ( 12, 13 ) ), ( ( 14, 0 ), ( 0, 15 ), ( 0, 16 ) ) )
    // CHECK: ( ( ( 1, 2 ), ( 0, 3 ), ( 4, 0 ) ), ( ( 5, 6 ), ( 0, 0 ), ( 0, 7 ) ), ( ( 8, 9 ), ( 10, 11 ), ( 12, 13 ) ), ( ( 14, 0 ), ( 0, 15 ), ( 0, 16 ) ) )
    //

    %d1 = sparse_tensor.convert %s1 : tensor<4x3x2xf32, #BatchedCSR> to tensor<4x3x2xf32>
    %v1 = vector.transfer_read %d1[%c0, %c0, %c0], %f0 : tensor<4x3x2xf32>, vector<4x3x2xf32>
    vector.print %v1 : vector<4x3x2xf32>

    %d2 = sparse_tensor.convert %s2 : tensor<4x3x2xf32, #CSRDense> to tensor<4x3x2xf32>
    %v2 = vector.transfer_read %d2[%c0, %c0, %c0], %f0 : tensor<4x3x2xf32>, vector<4x3x2xf32>
    vector.print %v2 : vector<4x3x2xf32>

    // FIXME: doing this explicitly crashes runtime
    // bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
    // bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
    return
  }
}
