
Commit ca0e5b7

Author: Peiming Liu (committed)
address comments
1 parent 520c703 commit ca0e5b7

File tree

2 files changed (+98, -12 lines)


mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 19 additions & 12 deletions
@@ -97,7 +97,7 @@ class DenseLevel : public SparseTensorLevel {
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
-    assert(!inPadZone && "Not implememnted");
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value posLo = MULI(p, lvlSize);
     return {posLo, lvlSize};
@@ -117,7 +117,7 @@ class BatchLevel : public SparseTensorLevel {
 
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
                         ValueRange parentPos, Value inPadZone) const override {
-    assert(!inPadZone && "Not implememnted");
+    assert(!inPadZone && "Not implemented");
     assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
     // No need to linearize the position for non-annotated tensors.
     return {C_IDX(0), lvlSize};
@@ -151,13 +151,14 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
 
     SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
     scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
-    // True branch.
+    // True branch, returns a "fake" empty range [0, 0) if parent
+    // iterator is in pad zone.
     b.setInsertionPointToStart(posRangeIf.thenBlock());
-    // Returns a "fake" empty range [0, 0) if parent iterator is in pad zone.
+
     SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
     b.create<scf::YieldOp>(l, emptyRange);
 
-    // False branch.
+    // False branch, returns the actual range.
     b.setInsertionPointToStart(posRangeIf.elseBlock());
     auto [pLo, pHi] = loadRange();
     SmallVector<Value, 2> loadedRange{pLo, pHi};
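
The scf.if built above guards the position-buffer loads on the parent iterator's pad-zone flag. As a rough sketch (value names are illustrative, not taken from this patch), the IR emitted for a compressed level under a padded dense parent has the following shape:

%posRange:2 = scf.if %inPadZone -> (index, index) {
  // Parent coordinate lies in the padding: yield an empty range [0, 0).
  scf.yield %c0, %c0 : index, index
} else {
  // Otherwise load the real position range from the positions buffer.
  %pLo = memref.load %positions[%loIdx] : memref<?xindex>
  %pHi = memref.load %positions[%hiIdx] : memref<?xindex>
  scf.yield %pLo, %pHi : index, index
}

The inner loop over the compressed level then runs from %posRange#0 to %posRange#1, so it executes zero iterations whenever the outer coordinate comes from padding; the new test added below checks this structure.
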
@@ -179,7 +180,7 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
                         ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 &&
            "loose-compressed level must be the first non-unique level.");
-    assert(!inPadZone && "Not implememnted");
+    assert(!inPadZone && "Not implemented");
     SmallVector<Value> memCrd(batchPrefix);
     Value p = parentPos.front();
     p = MULI(p, C_IDX(2));
@@ -200,7 +201,7 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
   ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
                         ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 || parentPos.size() == 2);
-    assert(!inPadZone && "Not implememnted");
+    assert(!inPadZone && "Not implemented");
     Value p = parentPos.front();
     Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
 
@@ -221,7 +222,7 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
                         ValueRange parentPos, Value inPadZone) const override {
     assert(parentPos.size() == 1 && isUnique() &&
            "n:m level can not be non-unique.");
-    assert(!inPadZone && "Not implememnted");
+    assert(!inPadZone && "Not implemented");
     // Each n:m blk has exactly n specified elements.
     auto n = getN(lt);
     Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -487,6 +488,7 @@ class SimpleWrapIterator : public SparseIterator {
   bool isBatchIterator() const override { return wrap->isBatchIterator(); }
   bool randomAccessible() const override { return wrap->randomAccessible(); };
   bool iteratableByFor() const override { return wrap->iteratableByFor(); };
+
   SmallVector<Value> serialize() const override { return wrap->serialize(); };
   void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
   ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
@@ -601,9 +603,7 @@ class PadIterator : public SimpleWrapIterator {
               Value padHigh)
       : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
                            wrap->randomAccessible() ? 1 : 0),
-        padLow(padLow), padHigh(padHigh) {
-    // assert(!randomAccessible());
-  }
+        padLow(padLow), padHigh(padHigh) {}
 
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
@@ -614,13 +614,20 @@ class PadIterator : public SimpleWrapIterator {
     return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
   }
 
+  // Returns a pair of values for *upper*, *lower* bound respectively.
+  ValuePair genForCond(OpBuilder &b, Location l) override {
+    if (randomAccessible())
+      return {getCrd(), upperBound(b, l)};
+    return wrap->genForCond(b, l);
+  }
+
   // For padded dense iterator, we append a `inPadZone: bool` in addition to
   // values used by the wrapped iterator.
   ValueRange getCurPosition() const override { return getCursor(); }
 
   SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
     SmallVector<Type> ret = wrap->getCursorValTypes(b);
-    // Need a extra boolean value `inPadZone` for padded dense iterator.
+    // Need an extra boolean value `inPadZone` for padded dense iterator.
     if (randomAccessible())
       ret.push_back(b.getI1Type());
 
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -canonicalize | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+}>
+
+#elemwise = {
+  indexing_maps = [
+    affine_map<(i,j) -> (i,j)>,  // A
+    affine_map<(i,j) -> (i,j)>,  // B
+    affine_map<(i,j) -> (i,j)>   // X (out)
+  ],
+  iterator_types = ["parallel", "parallel"],
+  doc = "X(i,j) = A(i,j) OP B(i,j)"
+}
+
+
+// CHECK-LABEL: func.func @padded_mul(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf32, #sparse>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf32>) -> tensor<8x8xf32> {
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant -1 : index
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<8x8xf32>
+// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_9]] : tensor<8x8xf32>) -> tensor<8x8xf32>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
+// CHECK: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_10]] : memref<8x8xf32>
+// CHECK: linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
+// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
+// CHECK: %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
+// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_7]] : index
+// CHECK: %[[VAL_18:.*]] = arith.cmpi uge, %[[VAL_15]], %[[VAL_3]] : index
+// CHECK: %[[VAL_19:.*]] = arith.ori %[[VAL_17]], %[[VAL_18]] : i1
+// CHECK: %[[VAL_20:.*]]:2 = scf.if %[[VAL_19]] -> (index, index) {
+// CHECK: scf.yield %[[VAL_6]], %[[VAL_6]] : index, index
+// CHECK: } else {
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_16]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: scf.yield %[[VAL_21]], %[[VAL_23]] : index, index
+// CHECK: }
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_20]]#0 to %[[VAL_20]]#1 step %[[VAL_5]] {
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
+// CHECK: %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
+// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
+// CHECK: memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: %[[VAL_31:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<8x8xf32>
+// CHECK: return %[[VAL_31]] : tensor<8x8xf32>
+// CHECK: }
+func.func @padded_mul(%arg0: tensor<4x4xf32, #CSR>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %cst_0 = arith.constant 0.00000e+00 : f32
+  %buf = tensor.empty() : tensor<8x8xf32>
+  %s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<8x8xf32>) -> tensor<8x8xf32>
+
+  %padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
+    ^bb0(%arg75: index, %arg76: index):
+      tensor.yield %cst_0 : f32
+  } : tensor<4x4xf32, #CSR> to tensor<8x8xf32, #CSR>
+
+  %0 = linalg.generic #elemwise
+    ins(%padded, %arg1: tensor<8x8xf32, #CSR>, tensor<8x8xf32>)
+    outs(%s: tensor<8x8xf32>) {
+      ^bb(%a: f32, %b: f32, %x: f32):
+        %0 = arith.mulf %a, %b : f32
+        linalg.yield %0 : f32
+  } -> tensor<8x8xf32>
+
+  return %0 : tensor<8x8xf32>
+}
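
The CHECK lines above exercise both halves of the change: the arith.cmpi/arith.ori sequence computes the extra `inPadZone` cursor value that the padded dense iterator appends, and the scf.if selects between the fake empty range and the loaded one. As a rough sketch for this test's 4x4 input padded low/high by 2 (value names are illustrative, not taken from the patch), the flag is derived along these lines:

%wrapCrd   = arith.subi %crd, %padLow : index        // coordinate in the unpadded level
%inLowPad  = arith.cmpi ult, %crd, %padLow : index   // crd < padLow (here 2)
%inHighPad = arith.cmpi uge, %crd, %hiBound : index  // crd >= padLow + lvlSize (here 6)
%inPadZone = arith.ori %inLowPad, %inHighPad : i1
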
