@@ -80,3 +80,51 @@ func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<
80
80
%9 = sparse_tensor.load %8 hasInserts : tensor <2 x4 xf64 , #BSR >
81
81
return %9 : tensor <2 x4 xf64 , #BSR >
82
82
}
83
+
84
+
85
+ // -----
86
+
87
+ #BSR = #sparse_tensor.encoding <{
88
+ map = ( i , j ) ->
89
+ ( i floordiv 2 : dense ,
90
+ j floordiv 2 : compressed ,
91
+ i mod 2 : dense ,
92
+ j mod 2 : dense
93
+ )
94
+ }>
95
+ // CHECK-DAG: #[[$remap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense) }>
96
+ // CHECK-DAG: #[[$demap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : dense, d3 : dense) }>
97
+
98
+ // CHECK-LABEL: func.func @sparse_assemble_reinterpret_map(
99
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<?xf64>,
100
+ // CHECK-SAME: %[[VAL_1:.*]]: tensor<?xindex>,
101
+ // CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>) -> tensor<2x4xf64, #[[$remap]]> {
102
+ // CHECK: %[[VAL_3:.*]] = sparse_tensor.assemble %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<1x2x2x2xf64, #[[$demap]]>
103
+ // CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_3]] : tensor<1x2x2x2xf64, #[[$demap]]> to tensor<2x4xf64, #[[$remap]]>
104
+ // CHECK: return %[[VAL_4]] : tensor<2x4xf64, #[[$remap]]>
105
+ // CHECK: }
106
+ func.func @sparse_assemble_reinterpret_map (%val : tensor <?xf64 >, %pos:tensor <?xindex >, %crd:tensor <?xindex >) -> tensor <2 x4 xf64 , #BSR > {
107
+ %0 = sparse_tensor.assemble %val , %pos , %crd
108
+ : tensor <?xf64 >, tensor <?xindex >, tensor <?xindex > to tensor <2 x4 xf64 , #BSR >
109
+ return %0 : tensor <2 x4 xf64 , #BSR >
110
+ }
111
+
112
+ // CHECK-LABEL: func.func @sparse_disassemble_reinterpret_map(
113
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64, #[[$remap]]>,
114
+ // CHECK-SAME: %[[VAL_1:.*]]: tensor<?xf64>,
115
+ // CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>,
116
+ // CHECK-SAME: %[[VAL_3:.*]]: tensor<?xindex>) -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
117
+ // CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64, #[[$remap]]> to tensor<1x2x2x2xf64, #[[$demap]]>
118
+ // CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]]:2, %[[VAL_7:.*]], %[[VAL_8:.*]]:2 = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]>
119
+ // CHECK: return
120
+ // CHECK: }
121
+ func.func @sparse_disassemble_reinterpret_map (%sp : tensor <2 x4 xf64 , #BSR >,
122
+ %od : tensor <?xf64 >,
123
+ %op : tensor <?xindex >,
124
+ %oi : tensor <?xindex >)
125
+ -> (tensor <?xf64 >, tensor <?xindex >, tensor <?xindex >) {
126
+ %rd , %rp , %ri , %dl , %pl , %il = sparse_tensor.disassemble %sp : tensor <2 x4 xf64 , #BSR >
127
+ outs (%od , %op , %oi : tensor <?xf64 >, tensor <?xindex >, tensor <?xindex >)
128
+ -> tensor <?xf64 >, (tensor <?xindex >, tensor <?xindex >), index , (index , index )
129
+ return %rd , %rp , %ri : tensor <?xf64 >, tensor <?xindex >, tensor <?xindex >
130
+ }
0 commit comments