|
| 1 | +// RUN: mlir-opt %s --sparse-compiler | \ |
| 2 | +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ |
| 3 | +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ |
| 4 | +// RUN: FileCheck %s |
| 5 | + |
| 6 | +#SparseVector = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }> |
| 7 | + |
| 8 | +module { |
| 9 | + |
| 10 | + // |
| 11 | + // Sparse kernel. |
| 12 | + // |
| 13 | + func @sparse_dot(%a: tensor<1024xf32, #SparseVector>, |
| 14 | + %b: tensor<1024xf32, #SparseVector>) -> tensor<f32> { |
| 15 | + %x = linalg.init_tensor [] : tensor<f32> |
| 16 | + %dot = linalg.dot ins(%a, %b: tensor<1024xf32, #SparseVector>, |
| 17 | + tensor<1024xf32, #SparseVector>) |
| 18 | + outs(%x: tensor<f32>) -> tensor<f32> |
| 19 | + return %dot : tensor<f32> |
| 20 | + } |
| 21 | + |
| 22 | + // |
| 23 | + // Main driver. |
| 24 | + // |
| 25 | + func @entry() { |
| 26 | + // Setup two sparse vectors. |
| 27 | + %d1 = arith.constant sparse< |
| 28 | + [ [0], [1], [22], [23], [1022] ], [1.0, 2.0, 3.0, 4.0, 5.0] |
| 29 | + > : tensor<1024xf32> |
| 30 | + %d2 = arith.constant sparse< |
| 31 | + [ [22], [1022], [1023] ], [6.0, 7.0, 8.0] |
| 32 | + > : tensor<1024xf32> |
| 33 | + %s1 = sparse_tensor.convert %d1 : tensor<1024xf32> to tensor<1024xf32, #SparseVector> |
| 34 | + %s2 = sparse_tensor.convert %d2 : tensor<1024xf32> to tensor<1024xf32, #SparseVector> |
| 35 | + |
| 36 | + // Call the kernel and verify the output. |
| 37 | + // |
| 38 | + // CHECK: 53 |
| 39 | + // |
| 40 | + %0 = call @sparse_dot(%s1, %s2) : (tensor<1024xf32, #SparseVector>, |
| 41 | + tensor<1024xf32, #SparseVector>) -> tensor<f32> |
| 42 | + %1 = tensor.extract %0[] : tensor<f32> |
| 43 | + vector.print %1 : f32 |
| 44 | + |
| 45 | + // Release the resources. |
| 46 | + sparse_tensor.release %s1 : tensor<1024xf32, #SparseVector> |
| 47 | + sparse_tensor.release %s2 : tensor<1024xf32, #SparseVector> |
| 48 | + %m = bufferization.to_memref %0 : memref<f32> |
| 49 | + memref.dealloc %m : memref<f32> |
| 50 | + |
| 51 | + return |
| 52 | + } |
| 53 | +} |
0 commit comments