Skip to content

Commit 8394ec9

Browse files
authored
[mlir][sparse] add a few more cases to sparse_tensor.print test (#83338)
1 parent 68f0edf commit 8394ec9

File tree

1 file changed

+56
-2
lines changed

1 file changed

+56
-2
lines changed

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,24 @@
102102
)
103103
}>
104104

105+
#BSR0 = #sparse_tensor.encoding<{
106+
map = (i, j) -> (
107+
i floordiv 2 : dense,
108+
j floordiv 4 : compressed,
109+
i mod 2 : dense,
110+
j mod 4 : dense
111+
)
112+
}>
113+
114+
#BSC0 = #sparse_tensor.encoding<{
115+
map = (i, j) -> (
116+
j floordiv 4 : dense,
117+
i floordiv 2 : compressed,
118+
i mod 2 : dense,
119+
j mod 4 : dense
120+
)
121+
}>
122+
105123
module {
106124

107125
//
@@ -114,6 +132,21 @@ module {
114132
[ 0, 0, 0, 0, 0, 0, 0, 0 ],
115133
[ 0, 0, 3, 4, 0, 5, 0, 0 ] ]> : tensor<4x8xi32>
116134

135+
%XO = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #AllDense>
136+
%XT = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #AllDenseT>
137+
138+
// CHECK: ---- Sparse Tensor ----
139+
// CHECK-NEXT: nse = 32
140+
// CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 5, 0, 0,
141+
// CHECK-NEXT: ----
142+
sparse_tensor.print %XO : tensor<4x8xi32, #AllDense>
143+
144+
// CHECK-NEXT: ---- Sparse Tensor ----
145+
// CHECK-NEXT: nse = 32
146+
// CHECK-NEXT: values : ( 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0,
147+
// CHECK-NEXT: ----
148+
sparse_tensor.print %XT : tensor<4x8xi32, #AllDenseT>
149+
117150
%a = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSR>
118151
%b = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #DCSR>
119152
%c = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #CSC>
@@ -122,9 +155,10 @@ module {
122155
%f = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSRC>
123156
%g = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC>
124157
%h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
158+
%i = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR0>
159+
%j = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC0>
125160

126-
//
127-
// CHECK: ---- Sparse Tensor ----
161+
// CHECK-NEXT: ---- Sparse Tensor ----
128162
// CHECK-NEXT: nse = 5
129163
// CHECK-NEXT: pos[1] : ( 0, 2, 2, 2, 5,
130164
// CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
@@ -200,7 +234,25 @@ module {
200234
// CHECK-NEXT: ----
201235
sparse_tensor.print %h : tensor<4x8xi32, #BSCC>
202236

237+
// CHECK-NEXT: ---- Sparse Tensor ----
238+
// CHECK-NEXT: nse = 24
239+
// CHECK-NEXT: pos[1] : ( 0, 1, 3,
240+
// CHECK-NEXT: crd[1] : ( 0, 0, 1,
241+
// CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
242+
// CHECK-NEXT: ----
243+
sparse_tensor.print %i : tensor<4x8xi32, #BSR0>
244+
245+
// CHECK-NEXT: ---- Sparse Tensor ----
246+
// CHECK-NEXT: nse = 24
247+
// CHECK-NEXT: pos[1] : ( 0, 2, 3,
248+
// CHECK-NEXT: crd[1] : ( 0, 1, 1,
249+
// CHECK-NEXT: values : ( 1, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 0, 0, 0, 0, 0, 5, 0, 0,
250+
// CHECK-NEXT: ----
251+
sparse_tensor.print %j : tensor<4x8xi32, #BSC0>
252+
203253
// Release the resources.
254+
bufferization.dealloc_tensor %XO : tensor<4x8xi32, #AllDense>
255+
bufferization.dealloc_tensor %XT : tensor<4x8xi32, #AllDenseT>
204256
bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
205257
bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
206258
bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
@@ -209,6 +261,8 @@ module {
209261
bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
210262
bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
211263
bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
264+
bufferization.dealloc_tensor %i : tensor<4x8xi32, #BSR0>
265+
bufferization.dealloc_tensor %j : tensor<4x8xi32, #BSC0>
212266

213267
return
214268
}

0 commit comments

Comments
 (0)