|
41 | 41 | doc = "X(i,j) *= A(i,j) * B(j,i)"
|
42 | 42 | }
|
43 | 43 |
|
| 44 | +#CSR = #sparse_tensor.encoding<{ |
| 45 | + map = ( i, j ) -> (i : dense, j : compressed) |
| 46 | +}> |
| 47 | + |
44 | 48 |
|
45 | 49 | #BSR = #sparse_tensor.encoding<{
|
46 | 50 | map = ( i, j ) ->
|
@@ -89,6 +93,20 @@ func.func @mul_24(%arg0: tensor<4x8xf64>,
|
89 | 93 | return %0 : tensor<4x4xf64>
|
90 | 94 | }
|
91 | 95 |
|
| 96 | +func.func @mul_csr_bsr(%arg0: tensor<4x8xf64, #CSR>, |
| 97 | + %arg1: tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> { |
| 98 | + %out = arith.constant dense<0.0> : tensor<4x4xf64> |
| 99 | + %0 = linalg.generic #trait_mul |
| 100 | + ins(%arg0, %arg1: tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) |
| 101 | + outs(%out: tensor<4x4xf64>) { |
| 102 | + ^bb(%x: f64, %y : f64, %z : f64): |
| 103 | + %1 = arith.mulf %x, %y : f64 |
| 104 | + %2 = arith.addf %1, %z : f64 |
| 105 | + linalg.yield %2 : f64 |
| 106 | + } -> tensor<4x4xf64> |
| 107 | + return %0 : tensor<4x4xf64> |
| 108 | +} |
| 109 | + |
92 | 110 | func.func @mul_dense(%arg0: tensor<4x8xf64>,
|
93 | 111 | %arg1: tensor<4x8xf64>) -> tensor<4x4xf64> {
|
94 | 112 | %out = arith.constant dense<0.0> : tensor<4x4xf64>
|
@@ -132,18 +150,22 @@ func.func @mul_dense(%arg0: tensor<4x8xf64>,
|
132 | 150 |
|
133 | 151 | %2 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #BSR>
|
134 | 152 | %3 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #NV_24>
|
| 153 | + %4 = sparse_tensor.convert %td : tensor<4x8xf64> to tensor<4x8xf64, #CSR> |
135 | 154 |
|
136 | 155 | %d = call @mul_dense(%td, %td)
|
137 | 156 | : (tensor<4x8xf64>, tensor<4x8xf64>) -> tensor<4x4xf64>
|
138 | 157 | %s = call @mul(%td, %2)
|
139 | 158 | : (tensor<4x8xf64>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64>
|
140 | 159 | %s24 = call @mul_24(%td, %3)
|
141 | 160 | : (tensor<4x8xf64>, tensor<4x8xf64, #NV_24>) -> tensor<4x4xf64>
|
| 161 | + %scsr = call @mul_csr_bsr(%4, %2) |
| 162 | + : (tensor<4x8xf64, #CSR>, tensor<4x8xf64, #BSR>) -> tensor<4x4xf64> |
142 | 163 |
|
143 |
| - // CHECK-COUNT-3: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) ) |
| 164 | + // CHECK-COUNT-4: ( ( 46, 115, 0, 0 ), ( 115, 306, 0, 0 ), ( 0, 0, 858, 1206 ), ( 0, 0, 1206, 1698 ) ) |
144 | 165 | call @dumpf64(%d) : (tensor<4x4xf64>) -> ()
|
145 | 166 | call @dumpf64(%s) : (tensor<4x4xf64>) -> ()
|
146 | 167 | call @dumpf64(%s24) : (tensor<4x4xf64>) -> ()
|
| 168 | + call @dumpf64(%scsr) : (tensor<4x4xf64>) -> () |
147 | 169 |
|
148 | 170 | return
|
149 | 171 | }
|
|
0 commit comments