Commit 46e77b5
[mlir][sparse] add a sparse quantized_matmul example to integration test
Note that this revision adds a very tiny bit of constant folding in the sparse compiler lattice construction. Although I am generally trying to avoid such canonicalizations (and rely on other passes to fix this instead), the benefits of avoiding a very expensive disjunction lattice construction justify having this special code (at least for now).

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D109939
1 parent d4e1617 commit 46e77b5
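For intuition on why the disjunction is expensive, here is a deliberately simplified, hypothetical model in plain C++ (the names LatticeSet, takeDisjunction, and buildSubLattices are invented for this sketch and are not the actual mlir::sparse_tensor::Merger API): a disjunction of two lattice sets of sizes m and n produces m*n + m + n points, whereas folding x - 0 to x keeps just the m points of the left operand.

// Hypothetical, heavily simplified model of lattice construction;
// not the actual Merger data structures.
#include <cstdio>
#include <string>
#include <vector>

using LatticeSet = std::vector<std::string>; // one string per lattice point

// Disjunction: conjoin every pair of points, then keep each side's points too.
static LatticeSet takeDisjunction(const LatticeSet &l, const LatticeSet &r) {
  LatticeSet out;
  for (const auto &p : l)
    for (const auto &q : r)
      out.push_back("(" + p + "&" + q + ")");
  out.insert(out.end(), l.begin(), l.end());
  out.insert(out.end(), r.begin(), r.end());
  return out;
}

// Lattices for `x - y`: fold when y is a known zero, otherwise disjoin.
static LatticeSet buildSubLattices(const LatticeSet &x, const LatticeSet &y,
                                   bool yIsConstantZero) {
  if (yIsConstantZero)
    return x; // x - 0 == x: reuse the left lattice, no growth at all
  return takeDisjunction(x, y);
}

int main() {
  LatticeSet x = {"a0", "a1", "a2"}, zero = {"c0"};
  std::printf("folded:   %zu points\n", buildSubLattices(x, zero, true).size());  // 3
  std::printf("unfolded: %zu points\n", buildSubLattices(x, zero, false).size()); // 7
  return 0;
}

Because lattice sets are built recursively over the expression tree, such blow-ups compound, which is what makes even this tiny fold worthwhile.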

3 files changed: +94 −0 lines changed

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 1 addition & 0 deletions
@@ -230,6 +230,7 @@ class Merger {
                  Value v1);
 
 private:
+  bool isZero(unsigned e) const;
   bool maybeZero(unsigned e) const;
   bool isInvariant(unsigned e) const;
   Type inferType(unsigned e, Value src);

mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 17 additions & 0 deletions
@@ -489,6 +489,11 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
     //  ---+---+---+    ---+---+---+
     //  !x | 0 | y |    !x | 0 |-y |
     //   x | x |x+y|     x | x |x-y|
+    //
+    // TODO: remove this zero "folding" in favor of external pass into linalg
+    //
+    if (isZero(tensorExps[e].children.e1))
+      return buildLattices(tensorExps[e].children.e0, i);
     return takeDisj(kind, // take binary disjunction
                     buildLattices(tensorExps[e].children.e0, i),
                     buildLattices(tensorExps[e].children.e1, i));
@@ -511,6 +516,18 @@ Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
   return buildTensorExp(op, yield->getOperand(0));
 }
 
+/// Only returns true if we are certain this is a zero.
+bool Merger::isZero(unsigned e) const {
+  if (tensorExps[e].kind == kInvariant) {
+    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
+      return c.getValue() == 0;
+    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
+      return c.getValue().isZero();
+  }
+  return false;
+}
+
+/// Only returns false if we are certain this is a nonzero.
 bool Merger::maybeZero(unsigned e) const {
   if (tensorExps[e].kind == kInvariant) {
     if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
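To make the payoff concrete: with the zero offset folded away, the expression reduces to x(i,j) += (ext(a(i,k)) - 2) * ext(b(k,j)), whose second factor is zero exactly where the sparse operand stores nothing, so the generated kernel only needs to visit stored entries. The following is a rough, hand-written C++ analogue of that loop structure (the CSR type and quantizedMatmulSparseRhs function are invented for illustration; this is not the code the sparsification pass actually emits):

#include <cstdint>
#include <vector>

// Minimal CSR storage for an i8 matrix (illustrative stand-in for the
// DCSR tensor used by the test below).
struct CSR {
  int cols;
  std::vector<int> rowPtr;    // size = rows + 1
  std::vector<int> colInd;    // column index per stored entry
  std::vector<int8_t> values; // stored (nonzero) values
};

// x(i,j) += (a(i,k) - 2) * b(k,j), with b sparse: because b's offset is zero,
// only the stored entries of b contribute, so the inner loop runs over them.
void quantizedMatmulSparseRhs(const int8_t *a, int rows, int K, const CSR &b,
                              int32_t *x) {
  for (int i = 0; i < rows; ++i)
    for (int k = 0; k < K; ++k) {
      int32_t lhs = int32_t(a[i * K + k]) - 2; // dense operand, offset 2
      for (int p = b.rowPtr[k]; p < b.rowPtr[k + 1]; ++p)
        x[i * b.cols + b.colInd[p]] += lhs * int32_t(b.values[p]);
    }
}

A nonzero offset on b would make (b(k,j) - offset) nonzero even at unstored positions, forcing a dense inner loop; that is why the test below puts the zero offset on the sparse operand.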
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s \
+// RUN:   --linalg-generalize-named-ops \
+// RUN:   --sparsification --sparse-tensor-conversion \
+// RUN:   --convert-vector-to-scf --convert-scf-to-std \
+// RUN:   --func-bufferize --tensor-constant-bufferize --tensor-bufferize \
+// RUN:   --std-bufferize --finalizing-bufferize --lower-affine \
+// RUN:   --convert-vector-to-llvm --convert-memref-to-llvm \
+// RUN:   --convert-std-to-llvm --reconcile-unrealized-casts | \
+// RUN: mlir-cpu-runner \
+// RUN:   -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+#DCSR = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+// An example of a quantized sparse matmul. With the zero offset for the
+// sparse input, the sparse compiler generates very efficient code for the
+//   x(i,j) += (ext(a(i,k)) - 2) * ext(b(k,j))
+// operation.
+module {
+
+  func @quantized_matmul(%input1: tensor<5x3xi8>,
+                         %input2: tensor<3x6xi8, #DCSR>,
+                         %output: tensor<5x6xi32>) -> tensor<5x6xi32> {
+    %c0 = constant 0 : i32
+    %c2 = constant 2 : i32
+    %0 = linalg.quantized_matmul
+      ins(%input1, %input2, %c2, %c0 : tensor<5x3xi8>, tensor<3x6xi8, #DCSR>, i32, i32)
+      outs(%output : tensor<5x6xi32>) -> tensor<5x6xi32>
+    return %0: tensor<5x6xi32>
+  }
+
+  func @entry() {
+    %c0 = constant 0 : index
+    %i0 = constant 0 : i32
+
+    %input1 = constant dense<[
+      [ -128,   3, 127 ],
+      [    0,   0,   0 ],
+      [   11,   1,   0 ],
+      [    0,   5,  -1 ],
+      [   13,   0,   3 ]
+    ]> : tensor<5x3xi8>
+
+    %input2 = constant dense<[
+      [  127,   0, -128,    0,   0,   3 ],
+      [    0,   0,    0,    0,   0,   0 ],
+      [    0,   0,    0,  100,  10,   0 ]
+    ]> : tensor<3x6xi8>
+
+    %sparse_input2 = sparse_tensor.convert %input2 : tensor<3x6xi8> to tensor<3x6xi8, #DCSR>
+
+    // Call the kernel.
+    %output = constant dense<0> : tensor<5x6xi32>
+    %0 = call @quantized_matmul(%input1, %sparse_input2, %output)
+       : (tensor<5x3xi8>,
+          tensor<3x6xi8, #DCSR>,
+          tensor<5x6xi32>) -> tensor<5x6xi32>
+
+    //
+    // Verify the output.
+    //
+    // CHECK:      ( ( -16510, 0, 16640, 12500, 1250, -390 ),
+    // CHECK-SAME:   ( -254, 0, 256, -200, -20, -6 ),
+    // CHECK-SAME:   ( 1143, 0, -1152, -200, -20, 27 ),
+    // CHECK-SAME:   ( -254, 0, 256, -300, -30, -6 ),
+    // CHECK-SAME:   ( 1397, 0, -1408, 100, 10, 33 ) )
+    //
+    %m = memref.buffer_cast %0 : memref<5x6xi32>
+    %v = vector.transfer_read %m[%c0, %c0], %i0
+      : memref<5x6xi32>, vector<5x6xi32>
+    vector.print %v : vector<5x6xi32>
+
+    return
+  }
+}
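As a standalone sanity check on the CHECK values above (not part of the commit), a plain dense C++ reference of x(i,j) += (a(i,k) - 2) * (b(k,j) - 0) reproduces the expected 5x6 result; for instance, entry (0,0) is (-128-2)*127 + (3-2)*0 + (127-2)*0 = -16510.

#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // Same inputs as the test above, kept dense for the reference computation.
  const int8_t a[5][3] = {{-128, 3, 127}, {0, 0, 0}, {11, 1, 0},
                          {0, 5, -1},     {13, 0, 3}};
  const int8_t b[3][6] = {{127, 0, -128, 0, 0, 3},
                          {0, 0, 0, 0, 0, 0},
                          {0, 0, 0, 100, 10, 0}};
  // Expected result, copied from the CHECK lines of the test.
  const int32_t expected[5][6] = {{-16510, 0, 16640, 12500, 1250, -390},
                                  {-254, 0, 256, -200, -20, -6},
                                  {1143, 0, -1152, -200, -20, 27},
                                  {-254, 0, 256, -300, -30, -6},
                                  {1397, 0, -1408, 100, 10, 33}};
  int32_t x[5][6] = {};
  for (int i = 0; i < 5; ++i)
    for (int j = 0; j < 6; ++j)
      for (int k = 0; k < 3; ++k)
        x[i][j] += (int32_t(a[i][k]) - 2) * int32_t(b[k][j]); // offset of b is 0
  for (int i = 0; i < 5; ++i)
    for (int j = 0; j < 6; ++j)
      assert(x[i][j] == expected[i][j]);
  std::puts("reference result matches the CHECK values");
  return 0;
}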
