23
23
// REDEFINE: %{sparsifier_opts} = enable-runtime-library=false
24
24
// RUN: %{compile} | %{run} | FileCheck %s
25
25
26
+ #CCC = #sparse_tensor.encoding <{ // encoding of the assembled tensor %s0; all three levels are compressed
27
+ map = (d0 , d1 , d2 ) -> (d0 : compressed , d1 : compressed , d2 : compressed),
28
+ posWidth = 64 , // matches the i64 position tensors passed to sparse_tensor.assemble
29
+ crdWidth = 32 // matches the i32 coordinate tensors passed to sparse_tensor.assemble
30
+ }>
31
+
26
32
#BatchedCSR = #sparse_tensor.encoding <{
27
33
map = (d0 , d1 , d2 ) -> (d0 : dense , d1 : dense , d2 : compressed),
28
34
posWidth = 64 ,
35
41
crdWidth = 32
36
42
}>
37
43
38
- // Test with batched-CSR and CSR-dense.
44
+ //
45
+ // Test assembly operation with CCC, batched-CSR and CSR-dense.
46
+ //
39
47
module {
40
48
//
41
49
// Main driver.
@@ -44,6 +52,31 @@ module {
44
52
%c0 = arith.constant 0 : index
45
53
%f0 = arith.constant 0.0 : f32
46
54
55
+ //
56
+ // Setup CCC.
57
+ //
58
+
59
+ %data0 = arith.constant dense <
60
+ [ 1.0 , 2.0 , 3.0 , 4.0 , 5.0 , 6.0 , 7.0 , 8.0 ]> : tensor <8 xf32 >
61
+ %pos00 = arith.constant dense <
62
+ [ 0 , 3 ]> : tensor <2 xi64 >
63
+ %crd00 = arith.constant dense <
64
+ [ 0 , 2 , 3 ]> : tensor <3 xi32 >
65
+ %pos01 = arith.constant dense <
66
+ [ 0 , 2 , 4 , 5 ]> : tensor <4 xi64 >
67
+ %crd01 = arith.constant dense <
68
+ [ 0 , 1 , 1 , 2 , 1 ]> : tensor <5 xi32 >
69
+ %pos02 = arith.constant dense <
70
+ [ 0 , 2 , 4 , 5 , 7 , 8 ]> : tensor <6 xi64 >
71
+ %crd02 = arith.constant dense <
72
+ [ 0 , 1 , 0 , 1 , 0 , 0 , 1 , 0 ]> : tensor <8 xi32 >
73
+
74
+ %s0 = sparse_tensor.assemble %data0 , %pos00 , %crd00 , %pos01 , %crd01 , %pos02 , %crd02 :
75
+ tensor <8 xf32 >,
76
+ tensor <2 xi64 >, tensor <3 xi32 >,
77
+ tensor <4 xi64 >, tensor <5 xi32 >,
78
+ tensor <6 xi64 >, tensor <8 xi32 > to tensor <4 x3 x2 xf32 , #CCC >
79
+
47
80
//
48
81
// Setup BatchedCSR.
49
82
//
@@ -75,10 +108,15 @@ module {
75
108
//
76
109
// Verify.
77
110
//
111
+ // CHECK: ( ( ( 1, 2 ), ( 3, 4 ), ( 0, 0 ) ), ( ( 0, 0 ), ( 0, 0 ), ( 0, 0 ) ), ( ( 0, 0 ), ( 5, 0 ), ( 6, 7 ) ), ( ( 0, 0 ), ( 8, 0 ), ( 0, 0 ) ) )
78
112
// CHECK: ( ( ( 1, 2 ), ( 0, 3 ), ( 4, 0 ) ), ( ( 5, 6 ), ( 0, 0 ), ( 0, 7 ) ), ( ( 8, 9 ), ( 10, 11 ), ( 12, 13 ) ), ( ( 14, 0 ), ( 0, 15 ), ( 0, 16 ) ) )
79
113
// CHECK: ( ( ( 1, 2 ), ( 0, 3 ), ( 4, 0 ) ), ( ( 5, 6 ), ( 0, 0 ), ( 0, 7 ) ), ( ( 8, 9 ), ( 10, 11 ), ( 12, 13 ) ), ( ( 14, 0 ), ( 0, 15 ), ( 0, 16 ) ) )
80
114
//
81
115
116
+ %d0 = sparse_tensor.convert %s0 : tensor <4 x3 x2 xf32 , #CCC > to tensor <4 x3 x2 xf32 >
117
+ %v0 = vector.transfer_read %d0 [%c0 , %c0 , %c0 ], %f0 : tensor <4 x3 x2 xf32 >, vector <4 x3 x2 xf32 >
118
+ vector.print %v0 : vector <4 x3 x2 xf32 >
119
+
82
120
%d1 = sparse_tensor.convert %s1 : tensor <4 x3 x2 xf32 , #BatchedCSR > to tensor <4 x3 x2 xf32 >
83
121
%v1 = vector.transfer_read %d1 [%c0 , %c0 , %c0 ], %f0 : tensor <4 x3 x2 xf32 >, vector <4 x3 x2 xf32 >
84
122
vector.print %v1 : vector <4 x3 x2 xf32 >
@@ -88,6 +126,7 @@ module {
88
126
vector.print %v2 : vector <4 x3 x2 xf32 >
89
127
90
128
// FIXME: deallocating these tensors explicitly crashes at runtime, so the ops below stay commented out.
129
+ // bufferization.dealloc_tensor %s0 : tensor<4x3x2xf32, #CCC>
91
130
// bufferization.dealloc_tensor %s1 : tensor<4x3x2xf32, #BatchedCSR>
92
131
// bufferization.dealloc_tensor %s2 : tensor<4x3x2xf32, #CSRDense>
93
132
return
0 commit comments