[mlir][sparse] support tensor.pad on CSR tensors #90687
Merged
Conversation
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Full diff: https://github.com/llvm/llvm-project/pull/90687.diff

3 Files Affected:
- mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
- mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
- mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
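Before the diff, a minimal sketch of the kind of IR this patch enables. The encoding below is the standard CSR map for a 2-D tensor; the function name and shapes are illustrative only (the integration test further down exercises the same path on a 4-D CDCC_NHWC tensor):

#CSR = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed)
}>

// Pad an 8x8 CSR matrix with a 2-wide zero border on each side.
func.func @pad_csr(%arg0: tensor<8x8xf32, #CSR>) -> tensor<12x12xf32, #CSR> {
  %f0 = arith.constant 0.0 : f32
  %padded = tensor.pad %arg0 low[2, 2] high[2, 2] {
  ^bb0(%i: index, %j: index):
    tensor.yield %f0 : f32
  } : tensor<8x8xf32, #CSR> to tensor<12x12xf32, #CSR>
  return %padded : tensor<12x12xf32, #CSR>
}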
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
index caf55072ce32e6..112b9f6c252786 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp
@@ -95,8 +95,9 @@ class DenseLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
+ assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value posLo = MULI(p, lvlSize);
return {posLo, lvlSize};
@@ -115,7 +116,8 @@ class BatchLevel : public SparseTensorLevel {
ValueRange getLvlBuffers() const override { return {}; }
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
+ assert(!inPadZone && "Not implemented");
assert(parentPos.size() == 1 && "Dense level can not be non-unique.");
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
@@ -129,18 +131,41 @@ class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"compressed level must be the first non-unique level.");
- Value p = parentPos.front();
- SmallVector<Value> memCrd(batchPrefix);
- memCrd.push_back(p);
- Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
- memCrd.back() = ADDI(p, C_IDX(1));
- Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
- return {pLo, pHi};
+ auto loadRange = [&b, l, parentPos, batchPrefix, this]() -> ValuePair {
+ Value p = parentPos.front();
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(p);
+ Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
+ memCrd.back() = ADDI(p, C_IDX(1));
+ Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
+ return {pLo, pHi};
+ };
+
+ if (inPadZone == nullptr)
+ return loadRange();
+
+ SmallVector<Type, 2> types{b.getIndexType(), b.getIndexType()};
+ scf::IfOp posRangeIf = b.create<scf::IfOp>(l, types, inPadZone, true);
+ // True branch.
+ b.setInsertionPointToStart(posRangeIf.thenBlock());
+ // Returns a "fake" empty range [0, 0) if the parent iterator is in the pad zone.
+ SmallVector<Value, 2> emptyRange{C_IDX(0), C_IDX(0)};
+ b.create<scf::YieldOp>(l, emptyRange);
+
+ // False branch.
+ b.setInsertionPointToStart(posRangeIf.elseBlock());
+ auto [pLo, pHi] = loadRange();
+ SmallVector<Value, 2> loadedRange{pLo, pHi};
+ b.create<scf::YieldOp>(l, loadedRange);
+
+ b.setInsertionPointAfter(posRangeIf);
+ ValueRange posRange = posRangeIf.getResults();
+ return {posRange.front(), posRange.back()};
}
};
@@ -151,9 +176,10 @@ class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 &&
"loose-compressed level must be the first non-unique level.");
+ assert(!inPadZone && "Not implemented");
SmallVector<Value> memCrd(batchPrefix);
Value p = parentPos.front();
p = MULI(p, C_IDX(2));
@@ -172,8 +198,9 @@ class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 || parentPos.size() == 2);
+ assert(!inPadZone && "Not implemented");
Value p = parentPos.front();
Value segHi = parentPos.size() == 2 ? parentPos.back() : nullptr;
@@ -191,9 +218,10 @@ class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}
ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
- ValueRange parentPos) const override {
+ ValueRange parentPos, Value inPadZone) const override {
assert(parentPos.size() == 1 && isUnique() &&
"n:m level can not be non-unique.");
+ assert(!inPadZone && "Not implemented");
// Each n:m blk has exactly n specified elements.
auto n = getN(lt);
Value posLo = MULI(parentPos.front(), C_IDX(n));
@@ -325,23 +353,7 @@ class TrivialIterator : public ConcreteIterator {
};
void genInitImpl(OpBuilder &b, Location l,
- const SparseIterator *parent) override {
-
- if (isBatchIterator() && batchCrds.size() <= stl.lvl)
- batchCrds.resize(stl.lvl + 1, nullptr);
-
- Value c0 = C_IDX(0);
- ValueRange pPos = c0;
- // If the parent iterator is a batch iterator, we also start from 0 (but
- // on a different batch).
- if (parent && !parent->isBatchIterator())
- pPos = parent->getCurPosition();
-
- ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
- std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos);
- // Seek to the lowest position.
- seek(posLo);
- }
+ const SparseIterator *parent) override;
ValuePair genForCond(OpBuilder &b, Location l) override {
if (randomAccessible())
@@ -465,8 +477,9 @@ class DedupIterator : public ConcreteIterator {
// A util base-iterator that delegates all methods to the wrapped iterator.
class SimpleWrapIterator : public SparseIterator {
public:
- SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind)
- : SparseIterator(kind, *wrap), wrap(std::move(wrap)) {}
+ SimpleWrapIterator(std::unique_ptr<SparseIterator> &&wrap, IterKind kind,
+ unsigned extraCursorVal = 0)
+ : SparseIterator(kind, *wrap, extraCursorVal), wrap(std::move(wrap)) {}
SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
return wrap->getCursorValTypes(b);
@@ -586,9 +599,10 @@ class PadIterator : public SimpleWrapIterator {
public:
PadIterator(std::unique_ptr<SparseIterator> &&wrap, Value padLow,
Value padHigh)
- : SimpleWrapIterator(std::move(wrap), IterKind::kPad), padLow(padLow),
- padHigh(padHigh) {
- assert(!randomAccessible() && "Not implemented.");
+ : SimpleWrapIterator(std::move(wrap), IterKind::kPad,
+ wrap->randomAccessible() ? 1 : 0),
+ padLow(padLow), padHigh(padHigh) {
+ // Note: the previous `!randomAccessible()` restriction is lifted here.
}
// For LLVM-style RTTI.
@@ -600,6 +614,19 @@ class PadIterator : public SimpleWrapIterator {
return std::string("pad<") + wrap->getDebugInterfacePrefix() + ">";
}
+ // For a padded dense iterator, we append an `inPadZone: bool` value in
+ // addition to the values used by the wrapped iterator.
+ ValueRange getCurPosition() const override { return getCursor(); }
+
+ SmallVector<Type> getCursorValTypes(OpBuilder &b) const override {
+ SmallVector<Type> ret = wrap->getCursorValTypes(b);
+ // Need an extra boolean value `inPadZone` for the padded dense iterator.
+ if (randomAccessible())
+ ret.push_back(b.getI1Type());
+
+ return ret;
+ }
+
// The upper bound after padding becomes `size + padLow + padHigh`.
Value upperBound(OpBuilder &b, Location l) const override {
return ADDI(ADDI(wrap->upperBound(b, l), padLow), padHigh);
@@ -613,6 +640,14 @@ class PadIterator : public SimpleWrapIterator {
void locateImpl(OpBuilder &b, Location l, Value crd) override {
assert(randomAccessible());
+ wrap->locate(b, l, SUBI(crd, padLow));
+
+ // inPadZone = crd < padLow || crd >= size + padLow.
+ Value inPadLow = CMPI(ult, crd, padLow);
+ Value inPadHigh = CMPI(uge, crd, ADDI(wrap->upperBound(b, l), padLow));
+ getMutCursorVals().back() = ORI(inPadLow, inPadHigh);
+
+ updateCrd(crd);
}
Value padLow, padHigh;
@@ -1227,6 +1262,33 @@ ValueRange NonEmptySubSectIterator::inflateSubSectTree(
return p->inflateSubSectTree(b, l, reduc, visitDenseSubSect);
}
+void TrivialIterator::genInitImpl(OpBuilder &b, Location l,
+ const SparseIterator *parent) {
+
+ if (isBatchIterator() && batchCrds.size() <= stl.lvl)
+ batchCrds.resize(stl.lvl + 1, nullptr);
+
+ Value c0 = C_IDX(0);
+ ValueRange pPos = c0;
+ Value inPadZone = nullptr;
+ // If the parent iterator is a batch iterator, we also start from 0 (but
+ // on a different batch).
+ if (parent && !parent->isBatchIterator()) {
+ pPos = parent->getCurPosition();
+ if (llvm::isa<PadIterator>(parent) && parent->randomAccessible()) {
+ // A padded dense iterator creates a "sparse" pad zone, which needs to be
+ // handled specially.
+ inPadZone = pPos.back();
+ pPos = pPos.drop_back();
+ }
+ }
+
+ ValueRange batchPrefix = parent ? parent->getBatchCrds() : ValueRange{};
+ std::tie(posLo, posHi) = stl.peekRangeAt(b, l, batchPrefix, pPos, inPadZone);
+ // Seek to the lowest position.
+ seek(posLo);
+}
+
void NonEmptySubSectIterator::genInitImpl(OpBuilder &b, Location l,
const SparseIterator *) {
Value c0 = C_IDX(0);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
index 2e7eeb2a05f998..120a806536f190 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h
@@ -46,9 +46,9 @@ class SparseTensorLevel {
///
/// For a sparse level, [posLo, loopHi) specifies the range of index pointer
/// to load coordinate from the coordinate buffer.
- virtual std::pair<Value, Value> peekRangeAt(OpBuilder &b, Location l,
- ValueRange batchPrefix,
- ValueRange parentPos) const = 0;
+ virtual std::pair<Value, Value>
+ peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ ValueRange parentPos, Value inPadZone = nullptr) const = 0;
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
index 92fbbf54558237..50dd989416e2a0 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir
@@ -30,16 +30,8 @@
// Do the same run, but now with direct IR generation and VLA vectorization.
// RUN: %if mlir_arm_sve_tests %{ %{compile_sve} | %{run_sve} | FileCheck %s %}
-#CCCC = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : compressed, d2 : compressed, d3 : compressed)
-}>
-
-#CDCD = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : dense)
-}>
-
-#DCCD = #sparse_tensor.encoding<{
- map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : compressed, d3 : dense)
+#CDCC_NHWC = #sparse_tensor.encoding<{
+ map = (d0, d1, d2, d3) -> (d0 : compressed, d1 : dense, d2 : compressed, d3 : compressed)
}>
// Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f
@@ -66,7 +58,7 @@ func.func @conv_2d_nhwc_hwcf(%arg0: tensor<3x8x8x3xf32>, %arg1: tensor<5x5x3x1xf
return %ret : tensor<3x8x8x1xf32>
}
-func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
+func.func @conv_2d_nhwc_hwcf_CDCC_NHWC(%arg0: tensor<3x8x8x3xf32, #CDCC_NHWC>, %arg1: tensor<5x5x3x1xf32>) -> tensor<3x8x8x1xf32> {
%cst_0 = arith.constant 0.00000e+00 : f32
%buf = tensor.empty() : tensor<3x8x8x1xf32>
%s = linalg.fill ins(%cst_0 : f32) outs(%buf : tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
@@ -74,11 +66,11 @@ func.func @conv_2d_nhwc_hwcf_CCCC(%arg0: tensor<3x8x8x3xf32, #CCCC>, %arg1: tens
%padded = tensor.pad %arg0 low[0, 2, 2, 0] high[0, 2, 2, 0] {
^bb0(%arg75: index, %arg76: index, %arg77: index, %arg78: index):
tensor.yield %cst_0 : f32
- } : tensor<3x8x8x3xf32, #CCCC> to tensor<3x12x12x3xf32, #CCCC>
+ } : tensor<3x8x8x3xf32, #CDCC_NHWC> to tensor<3x12x12x3xf32, #CDCC_NHWC>
%ret = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
strides = dense<1> : tensor<2xi64>}
- ins (%padded, %arg1: tensor<3x12x12x3xf32, #CCCC>, tensor<5x5x3x1xf32>)
+ ins (%padded, %arg1: tensor<3x12x12x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>)
outs (%s: tensor<3x8x8x1xf32>) -> tensor<3x8x8x1xf32>
return %ret : tensor<3x8x8x1xf32>
}
@@ -105,8 +97,8 @@ func.func @main() {
%dense_ret = call @conv_2d_nhwc_hwcf(%static_input, %static_filter, %static_output) : (tensor<3x8x8x3xf32>, tensor<5x5x3x1xf32>, tensor<3x8x8x1xf32>) -> (tensor<3x8x8x1xf32>)
- %in2D_nhwc_CCCC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CCCC>
- %CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC(%in2D_nhwc_CCCC, %static_filter) : (tensor<3x8x8x3xf32, #CCCC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
+ %in2D_nhwc_CDCC_NHWC = sparse_tensor.convert %static_input : tensor<3x8x8x3xf32> to tensor<3x8x8x3xf32, #CDCC_NHWC>
+ %CDCC_NHWC_ret = call @conv_2d_nhwc_hwcf_CDCC_NHWC(%in2D_nhwc_CDCC_NHWC, %static_filter) : (tensor<3x8x8x3xf32, #CDCC_NHWC>, tensor<5x5x3x1xf32>) -> (tensor<3x8x8x1xf32>)
// CHECK: ( ( ( ( 108 ), ( 160 ), ( 196 ), ( 196 ), ( 196 ), ( 196 ), ( 144 ), ( 108 ) ),
@@ -161,17 +153,17 @@ func.func @main() {
// CHECK-SAME: ( ( 180 ), ( 240 ), ( 300 ), ( 300 ), ( 300 ), ( 300 ), ( 240 ), ( 180 ) ),
// CHECK-SAME: ( ( 144 ), ( 192 ), ( 240 ), ( 240 ), ( 240 ), ( 240 ), ( 192 ), ( 144 ) ),
// CHECK-SAME: ( ( 108 ), ( 144 ), ( 180 ), ( 180 ), ( 180 ), ( 180 ), ( 144 ), ( 108 ) ) ) )
- %CCCC_v = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
+ %CDCC_NHWC_v = vector.transfer_read %CDCC_NHWC_ret[%c0, %c0, %c0, %c0], %zero
: tensor<3x8x8x1xf32>, vector<3x8x8x1xf32>
- vector.print %CCCC_v : vector<3x8x8x1xf32>
+ vector.print %CDCC_NHWC_v : vector<3x8x8x1xf32>
bufferization.dealloc_tensor %static_filter : tensor<5x5x3x1xf32>
bufferization.dealloc_tensor %static_input : tensor<3x8x8x3xf32>
bufferization.dealloc_tensor %static_output : tensor<3x8x8x1xf32>
- bufferization.dealloc_tensor %CCCC_ret : tensor<3x8x8x1xf32>
+ bufferization.dealloc_tensor %CDCC_NHWC_ret : tensor<3x8x8x1xf32>
- bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<3x8x8x3xf32, #CCCC>
+ bufferization.dealloc_tensor %in2D_nhwc_CDCC_NHWC : tensor<3x8x8x3xf32, #CDCC_NHWC>
return
}
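To summarize the mechanism: `PadIterator::locateImpl` computes an `inPadZone` flag (`crd < padLow || crd >= size + padLow`) and carries it as an extra cursor value; `TrivialIterator::genInitImpl` peels that flag off the parent cursor and passes it to `CompressedLevel::peekRangeAt`, which guards its position-buffer loads behind an `scf.if`. A rough sketch of the resulting IR, with hypothetical value names (%crd, %pad_lo, %size, %pos, %p are illustrative, not names the compiler emits):

// inPadZone = crd < padLow || crd >= size + padLow (from locateImpl).
%bound     = arith.addi %size, %pad_lo : index
%in_pad_lo = arith.cmpi ult, %crd, %pad_lo : index
%in_pad_hi = arith.cmpi uge, %crd, %bound : index
%in_pad    = arith.ori %in_pad_lo, %in_pad_hi : i1

// peekRangeAt: yield the "fake" empty range [0, 0) inside the pad zone,
// otherwise load the real [pLo, pHi) range from the positions buffer.
%p_lo, %p_hi = scf.if %in_pad -> (index, index) {
  scf.yield %c0, %c0 : index, index
} else {
  %lo = memref.load %pos[%p] : memref<?xindex>
  %p1 = arith.addi %p, %c1 : index
  %hi = memref.load %pos[%p1] : memref<?xindex>
  scf.yield %lo, %hi : index, index
}

The empty range makes the child compressed level iterate over no stored entries inside the pad zone, which is exactly the behavior zero padding requires.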
aartbik reviewed May 1, 2024
Review comments (resolved) on:
- mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (two comments, now outdated)
- mlir/test/Integration/Dialect/SparseTensor/CPU/padded_sparse_conv_2d.mlir (one comment)
Force-pushed 69f555f to c83ada7
aartbik approved these changes May 1, 2024
Review comments (resolved) on:
- mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp (six comments, now outdated)
yinying-lisa-li approved these changes May 1, 2024