[mlir][sparse] support tensor.pad on CSR tensors #90687

Merged: 2 commits, May 1, 2024

Changes from all commits:
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos) const override {
ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value posLo = MULI(p, lvlSize);
return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
ValueRange parentPos) const override {
ValueRange parentPos, Value inPadZone) const override {
assert(!inPadZone && "Not implemented");
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
@@ -129,18 +131,42 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos) const override {
ValueRange parentPos, Value inPadZone) const override {

assert(parentPos.size() == 1 &&
"compressed level must be the first non-unique level.");
Value p = parentPos.front();

SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
Value p = parentPos.front();
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
};

if (inPadZone == nullptr)
return loadRange();

SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
// True branch, returns a "fake" empty range [0, 0) if the parent
// iterator is in the pad zone.
b.setInsertionPointToStart(posRangeIf.thenBlock());

SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
b.create<scf::YieldOp>(l, emptyRange);

// False branch, returns the actual range.
b.setInsertionPointToStart(posRangeIf.elseBlock());
auto [pLo, pHi] = loadRange();
SmallVector<Value, 2> loadedRange{pLo, pHi};
b.create<scf::YieldOp>(l, loadedRange);

b.setInsertionPointAfter(posRangeIf);
ValueRange posRange = posRangeIf.getResults();
return {posRange.front(), posRange.back()};
}
};
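
The scf.if above is the heart of the change: when the parent padded iterator reports that it currently sits in the pad zone, the compressed level yields an empty [0, 0) range instead of reading the positions buffer out of bounds. A minimal plain-C++ sketch of what the generated code computes at runtime (illustrative names and types, not part of the patch):

#include <cstddef>
#include <utility>
#include <vector>

// Either a "fake" empty range for padded rows, or the real [pLo, pHi)
// slice loaded from the positions buffer.
std::pair<std::size_t, std::size_t>
peekCompressedRange(const std::vector<std::size_t> &posBuf,
                    std::size_t parentPos, bool inPadZone) {
  if (inPadZone)
    return {0, 0}; // no stored entries in the pad zone
  return {posBuf[parentPos], posBuf[parentPos + 1]};
}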

@@ -151,9 +177,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos) const override {
ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"loose-compressed level must be the first non-unique level.");
assert(!inPadZone && "Not implemented");
SmallVector<Value> memCrd(batchPrefix);
Value p = parentPos.front();
p = MULI(p, C_IDX(2));
@@ -172,8 +199,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos) const override {
ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 || parentPos.size() == 2);
assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;

@@ -191,9 +219,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos) const override {
ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && isUnique() &&
"n:m level can not be non-unique.");
assert(!inPadZone && "Not implemented");
// Each n:m blk has exactly n specified elements.
auto n = getN(lt);
Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +354,7 @@ class TrivialIterator : public ConcreteIterator {
};

void genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) override {

if (isBatchIterator() && batchCrds.size() <= stl.lvl)
batchCrds.resize(stl.lvl + 1, nullptr);

Value c0 = C_IDX(0);
ValueRange pPos = c0;
// If the parent iterator is a batch iterator, we also start from 0 (but
// on a different batch).
if (parent && !parent->isBatchIterator())
pPos = parent->getCurPosition();

ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
// Seek to the lowest position.
seek(posLo);
}
const SparseIterator *parent) override;

ValuePair genForCond(OpBuilder &b, Location l) override {
if (randomAccessible())
@@ -465,15 +478,17 @@ class DedupIterator : public ConcreteIterator {
// A util base-iterator that delegates all methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
: SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
unsigned extraCursorVal = 0)
: SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}

SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
return wrap->getCursorValTypes(b);
}
bool isBatchIterator() const override { return wrap->isBatchIterator(); }
bool randomAccessible() const override { return wrap->randomAccessible(); };
bool iteratableByFor() const override { return wrap->iteratableByFor(); };

SmallVector<Value> serialize() const override { return wrap->serialize(); };
void deserialize(ValueRange vs) override { wrap->deserialize(vs); };
ValueRange getCurPosition() const override { return wrap->getCurPosition(); }
@@ -586,10 +601,9 @@ class PadIterator : public SimpleWrapIterator {
public:
PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
Value padHigh)
: SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
padHigh(padHigh) {
assert(!randomAccessible() && "Not implemented.");
}
: SimpleWrapIterator(std::move(wrap), IterKind::kPad,
wrap->randomAccessible() ? 1 : 0),
padLow(padLow), padHigh(padHigh) {}

// For LLVM-style RTTI.
static bool classof(const SparseIterator *from) {
@@ -600,6 +614,26 @@ class PadIterator : public SimpleWrapIterator {
return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
}

// Returns a pair of values for the *lower* and *upper* bound, respectively.
ValuePair genForCond(OpBuilder &b, Location l) override {
if (randomAccessible())
return {getCrd(), upperBound(b, l)};
return wrap->genForCond(b, l);
}

// For a padded dense iterator, we append an `inPadZone: bool` value in
// addition to the values used by the wrapped iterator.
ValueRange getCurPosition() const override { return getCursor(); }

SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
SmallVector<Type> ret = wrap->getCursorValTypes(b);
// Need an extra boolean value `inPadZone` for a padded dense iterator.
if (randomAccessible())
ret.push_back(b.getI1Type());

return ret;
}

// The upper bound after padding becomes `size + padLow + padHigh`.
Value upperBound(OpBuilder &b, Location l) const override {
return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +647,14 @@ class PadIterator : public SimpleWrapIterator {

void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
wrap->locate(b, l, SUBI(crd, padLow));

// inPadZone = crd < padLow || crd >= size + padLow.
Value inPadLow = CMPI(ult, crd, padLow);
Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
getMutCursorVals().back() = ORI(inPadLow, inPadHigh);

updateCrd(crd);
}

Value padLow, padHigh;
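
locateImpl above positions the wrapped iterator at crd - padLow and records in the extra cursor slot whether crd falls into the low or high pad region. A small plain-C++ sketch of the index arithmetic, using the 4x4 tensor padded by low[2, 2] high[2, 2] from the test below (illustrative; unsigned comparisons mirror the generated arith.cmpi ult/uge):

#include <cassert>
#include <cstdint>

struct PadLocate {
  uint64_t wrappedCrd; // coordinate handed to the wrapped iterator
  bool inPadZone;      // the extra i1 cursor value
};

// inPadZone = crd < padLow || crd >= size + padLow.
PadLocate locatePadded(uint64_t crd, uint64_t padLow, uint64_t size) {
  // crd - padLow wraps around in the low pad zone, but the wrapped position
  // is never dereferenced there because inPadZone selects the empty range.
  return {crd - padLow, crd < padLow || crd >= size + padLow};
}

int main() {
  assert(locatePadded(/*crd=*/1, /*padLow=*/2, /*size=*/4).inPadZone); // low pad
  assert(!locatePadded(3, 2, 4).inPadZone && locatePadded(3, 2, 4).wrappedCrd == 1);
  assert(locatePadded(6, 2, 4).inPadZone);                             // high pad
  return 0;
}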
@@ -1227,6 +1269,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}

void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
const SparseIterator *parent) {

if (isBatchIterator() && batchCrds.size() <= stl.lvl)
batchCrds.resize(stl.lvl + 1, nullptr);

Value c0 = C_IDX(0);
ValueRange pPos = c0;
Value inPadZone = nullptr;
// If the parent iterator is a batch iterator, we also start from 0 (but
// on a different batch).
if (parent && !parent->isBatchIterator()) {
pPos = parent->getCurPosition();
if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
// A padded dense iterator creates a "sparse" pad zone, which needs to be
// handled specially.
inPadZone = pPos.back();
pPos = pPos.drop_back();
}
}

ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
// Seek to the lowest position.
seek(posLo);
}
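
The only parent-specific handling added here is peeling the trailing i1 off a random-access PadIterator's cursor, which is laid out as [parent position values..., inPadZone]. A trivial plain-C++ model of that split (illustrative; the real code manipulates mlir::ValueRange SSA values):

#include <cassert>
#include <utility>
#include <vector>

// Splits a padded parent's cursor [pos..., inPadZone] into the position
// values that are forwarded to peekRangeAt and the pad-zone flag.
std::pair<std::vector<int>, int> splitPaddedCursor(std::vector<int> cursor) {
  assert(!cursor.empty() && "expected [pos..., inPadZone]");
  int inPadZone = cursor.back();
  cursor.pop_back();
  return {std::move(cursor), inPadZone};
}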

void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
const SparseIterator *) {
Value c0 = C_IDX(0);
@@ -46,9 +46,9 @@ class SparseTensorLevel {
///
/// For a sparse level, [posLo, loopHi) specifies the range of index pointers
/// used to load coordinates from the coordinate buffer.
virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
ValueRange batchPrefix,
ValueRange parentPos) const = 0;
virtual std::pair<Value, Value>
peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
ValueRange parentPos, Value inPadZone = nullptr) const = 0;

Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
@@ -0,0 +1,79 @@
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification -canonicalize | FileCheck %s

#CSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

#elemwise = {
indexing_maps = [
affine_map<(i,j) -> (i,j)>, // A
affine_map<(i,j) -> (i,j)>, // B
affine_map<(i,j) -> (i,j)> // X (out)
],
iterator_types = ["parallel", "parallel"],
doc = "X(i,j) = A(i,j) OP B(i,j)"
}


// CHECK-LABEL: func.func @padded_mul(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<4x4xf32, #sparse>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf32>) -> tensor<8x8xf32> {
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant -1 : index
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 6 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<8x8xf32>
// CHECK: %[[VAL_10:.*]] = linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_9]] : tensor<8x8xf32>) -> tensor<8x8xf32>
// CHECK: %[[VAL_11:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
// CHECK: %[[VAL_12:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<4x4xf32, #sparse> to memref<?xindex>
// CHECK: %[[VAL_13:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<4x4xf32, #sparse> to memref<?xf32>
// CHECK: %[[VAL_14:.*]] = bufferization.to_memref %[[VAL_10]] : memref<8x8xf32>
// CHECK: linalg.fill ins(%[[VAL_8]] : f32) outs(%[[VAL_14]] : memref<8x8xf32>)
// CHECK: scf.for %[[VAL_15:.*]] = %[[VAL_6]] to %[[VAL_4]] step %[[VAL_5]] {
// CHECK: %[[VAL_16:.*]] = arith.subi %[[VAL_15]], %[[VAL_7]] : index
// CHECK: %[[VAL_17:.*]] = arith.cmpi ult, %[[VAL_15]], %[[VAL_7]] : index
// CHECK: %[[VAL_18:.*]] = arith.cmpi uge, %[[VAL_15]], %[[VAL_3]] : index
// CHECK: %[[VAL_19:.*]] = arith.ori %[[VAL_17]], %[[VAL_18]] : i1
// CHECK: %[[VAL_20:.*]]:2 = scf.if %[[VAL_19]] -> (index, index) {
// CHECK: scf.yield %[[VAL_6]], %[[VAL_6]] : index, index
// CHECK: } else {
// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_16]]] : memref<?xindex>
// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_15]], %[[VAL_2]] : index
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_22]]] : memref<?xindex>
// CHECK: scf.yield %[[VAL_21]], %[[VAL_23]] : index, index
// CHECK: }
// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_20]]#0 to %[[VAL_20]]#1 step %[[VAL_5]] {
// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_24]]] : memref<?xindex>
// CHECK: %[[VAL_27:.*]] = arith.addi %[[VAL_26]], %[[VAL_7]] : index
// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_24]]] : memref<?xf32>
// CHECK: %[[VAL_29:.*]] = tensor.extract %[[VAL_1]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : tensor<8x8xf32>
// CHECK: %[[VAL_30:.*]] = arith.mulf %[[VAL_28]], %[[VAL_29]] : f32
// CHECK: memref.store %[[VAL_30]], %[[VAL_14]]{{\[}}%[[VAL_15]], %[[VAL_27]]] : memref<8x8xf32>
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: %[[VAL_31:.*]] = bufferization.to_tensor %[[VAL_14]] : memref<8x8xf32>
// CHECK: return %[[VAL_31]] : tensor<8x8xf32>
// CHECK: }
func.func @padded_mul(%arg0: tensor<4x4xf32, #CSR>, %arg1: tensor<8x8xf32>) -> tensor<8x8xf32> {
%cst_0 = arith.constant 0.00000e+00 : f32
%buf = tensor.empty() : tensor<8x8xf32>
%s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<8x8xf32>) -> tensor<8x8xf32>

%padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
^bb0(%arg75: index, %arg76: index):
tensor.yield %cst_0 : f32
} : tensor<4x4xf32, #CSR> to tensor<8x8xf32, #CSR>

%0 = linalg.generic #elemwise
ins(%padded, %arg1: tensor<8x8xf32, #CSR>, tensor<8x8xf32>)
outs(%s: tensor<8x8xf32>) {
^bb(%a: f32, %b: f32, %x: f32):
%0 = arith.mulf %a, %b : f32
linalg.yield %0 : f32
} -> tensor<8x8xf32>

return %0 : tensor<8x8xf32>
}
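
For reference, the kernel above computes the padded elementwise multiply that the following dense C++ loop nest spells out; coordinates outside the original 4x4 operand read the pad value 0 (a behavioral sketch of the test's semantics, not generated code):

#include <array>

// X(i, j) = pad(A)(i, j) * B(i, j), where pad(A) places the 4x4 tensor A at
// offset [2, 2] inside an 8x8 frame filled with zeros.
std::array<std::array<float, 8>, 8>
paddedMulReference(const std::array<std::array<float, 4>, 4> &A,
                   const std::array<std::array<float, 8>, 8> &B) {
  std::array<std::array<float, 8>, 8> X{};
  for (int i = 0; i < 8; ++i)
    for (int j = 0; j < 8; ++j) {
      bool inPad = i < 2 || i >= 6 || j < 2 || j >= 6;
      float a = inPad ? 0.0f : A[i - 2][j - 2];
      X[i][j] = a * B[i][j];
    }
  return X;
}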