-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][sparse] set up the skeleton for SparseTensorLevel abstraction. #75645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Peiming Liu (PeimingLiu) ChangesNote that at the current moment, the newly-introduced Patch is 20.38 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/75645.diff 5 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
index ad8b0d02eca35e..d3ab65e4e1793a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Utils/IterationGraphSorter.cpp
Utils/LoopEmitter.cpp
Utils/SparseTensorDescriptor.cpp
+ Utils/SparseTensorLevels.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index 784c793c9bd119..6fa11cc1757aa3 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -126,15 +126,15 @@ static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
// Generates a bool value for while loop condition that tries to iterate over a
// fully reduced level with affine index expression.
static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
- Value crdBuf, Value crdHi, Value posit,
- Value posHi) {
+ const SparseTensorLevel &level,
+ Value crdHi, Value posit, Value posHi) {
Value inBound = CMPI(ult, posit, posHi);
auto ifOp =
builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
// if (inbound)
// yield coord < crdHi
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
- Value crd = genIndexLoad(builder, loc, crdBuf, posit);
+ Value crd = level.peekCrdAt(builder, loc, posit);
YIELD(CMPI(ult, crd, crdHi));
// else
// yield false
@@ -244,13 +244,12 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
TensorId tid, Level lvl, Value pLo,
Value pHi) {
- const auto coordinates = coordinatesBuffers[tid][lvl];
- const auto sameCrd = genIndexLoad(builder, loc, coordinates, pLo);
+ SparseTensorLevel &level = *lvls[tid][lvl];
+ const Value sameCrd = level.peekCrdAt(builder, loc, pLo);
auto whileOp = builder.create<scf::WhileOp>(
loc, builder.getIndexType(), pLo,
/*beforeBuilder=*/
- [pHi, coordinates, sameCrd](OpBuilder &builder, Location loc,
- ValueRange ivs) {
+ [pHi, &level, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
const auto pos = ivs[0];
Value inBound = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, pos, pHi);
@@ -261,7 +260,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
// Load the next coordinates only when inbound (to avoid OOB
// accesses).
builder.setInsertionPointToStart(ifInBound.thenBlock());
- Value crd = genIndexLoad(builder, loc, coordinates, pos);
+ Value crd = level.peekCrdAt(builder, loc, pos);
Value isSameCrd = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, crd, sameCrd);
YIELD(isSameCrd);
@@ -284,11 +283,8 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
Level lvl) {
- // A load on the coordinates array yields the coordinate.
- const Value mem = coordinatesBuffers[tid][lvl];
- /// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
const Value pos = posits[tid][lvl];
- const Value crd = genIndexLoad(builder, loc, mem, pos);
+ const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
return crd;
}
@@ -318,9 +314,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->segHi.assign(numTensors, std::vector<Value>());
this->posits.assign(numTensors, std::vector<Value>());
this->coords.assign(numTensors, std::vector<Value>());
- this->positionsBuffers.assign(numTensors, std::vector<Value>());
- this->coordinatesBuffers.assign(numTensors, std::vector<Value>());
this->valBuffer.assign(numTensors, nullptr);
+ this->lvls.resize(numTensors);
this->isSparseSlices.assign(numTensors, false);
this->sliceOffsets.assign(numTensors, std::vector<Value>());
this->sliceStrides.assign(numTensors, std::vector<Value>());
@@ -377,8 +372,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
segHi[tid].assign(lvlRank, Value());
posits[tid].assign(lvlRank, Value());
coords[tid].assign(lvlRank, Value());
- positionsBuffers[tid].assign(lvlRank, Value());
- coordinatesBuffers[tid].assign(lvlRank, Value());
+ lvls[tid].resize(lvlRank);
+
sliceOffsets[tid].assign(lvlRank, Value());
sliceStrides[tid].assign(lvlRank, Value());
@@ -448,22 +443,7 @@ void LoopEmitter::initializeLoopEmit(
// Scan all levels of current tensor.
for (Level l = 0; l < lvlRank; l++) {
- // This should be called only once at beginning.
- assert(!positionsBuffers[t][l] && !coordinatesBuffers[t][l] &&
- !highs[t][l]);
- const auto lvlTp = lvlTypes[t][l];
- // Handle sparse storage schemes.
- if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
- // Generate sparse primitives to obtain positions and coordinates.
- positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
- coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
- } else if (isSingletonLT(lvlTp) || is2OutOf4LT(lvlTp)) {
- // Singleton level, fetch coordinates.
- coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
- } else {
- // Dense level, nothing to fetch.
- assert(isDenseLT(lvlTp));
- }
+ lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, l);
// Find upper bound in current dimension.
highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
@@ -756,8 +736,7 @@ Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, remSz);
}
assert(crdHi);
- return genSparseReducedAffineCond(builder, loc,
- coordinatesBuffers[tid][lvl], crdHi,
+ return genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl], crdHi,
ivs[0], highs[tid][lvl]);
}
case LoopCondKind::SparseAffineUnRedCond: {
@@ -802,10 +781,9 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
// Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
Value posit = ivs[0];
- Value crdBuf = coordinatesBuffers[tid][lvl];
// We need to substract the offset to get relative coordinates.
// TODO: Maybe assert relC >=0 during runtime in debug build?
- Value absC = genIndexLoad(builder, loc, crdBuf, posit);
+ Value absC = lvls[tid][lvl]->peekCrdAt(builder, loc, posit);
auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset);
posits[tid][lvl] = posit;
coords[tid][lvl] = relC;
@@ -1189,9 +1167,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
// The induction variable gives the position.
const Value pos = forOp.getInductionVar();
posits[tid][lvl] = pos;
- // Generating a load on the coordinates array yields the crd.
- const Value mem = coordinatesBuffers[tid][lvl];
- const Value crd = genIndexLoad(builder, loc, mem, pos);
+ const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
coords[tid][lvl] = crd;
// Generate an if-condition to filter out coordinates that are not
@@ -1255,7 +1231,10 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
assert(lvl == 0 || posits[tid][lvl - 1]);
if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
- const Value mem = positionsBuffers[tid][lvl];
+ const Value mem =
+ isCompressedLT(lvlTp)
+ ? static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer
+ : static_cast<LooseCompressedLevel &>(*lvls[tid][lvl]).posBuffer;
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
if (isLooseCompressedLT(lvlTp))
@@ -1623,8 +1602,7 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
/*beforeBuilder=*/
[this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
ValueRange args) {
- Value cond = genSparseReducedAffineCond(builder, loc,
- coordinatesBuffers[tid][lvl],
+ Value cond = genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl],
sliceHi, args[0], posHi);
// continue if not yet break nor out of bound.
builder.create<scf::ConditionOp>(loc, cond, args);
@@ -1848,12 +1826,12 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
Value pHi, pLo;
if (lvl == 0) {
pLo = c0;
- pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
+ Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][0]).posBuffer;
+ pHi = genIndexLoad(builder, loc, pBuf, c1);
} else {
- pLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
- posits[tid][lvl - 1]);
- pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
- ADDI(posits[tid][lvl - 1], c1));
+ Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
+ pLo = genIndexLoad(builder, loc, pBuf, posits[tid][lvl - 1]);
+ pHi = genIndexLoad(builder, loc, pBuf, ADDI(posits[tid][lvl - 1], c1));
}
// Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
@@ -1868,7 +1846,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
// nonempty. though we assume that even on empty sparse tensors, a non-empty
// ptr/idx buffer is allocated for each level so it would not cause OOB to
// avoid generating a ifOp here.
- Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
+ Value minCrd = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
// FIXME: We need the relative offset related to the base slice.
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
@@ -1955,9 +1933,9 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
Value &curTupleCnt = reduc[2];
Value pHi = ADDI(iv, c1);
- Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
- Value sPHi =
- genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi);
+ Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
+ Value sPLo = genIndexLoad(builder, loc, pBuf, iv);
+ Value sPHi = genIndexLoad(builder, loc, pBuf, pHi);
// isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
// one non-empty lvl, the slice is non-empty.
@@ -1975,8 +1953,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
// }
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
- Value curC =
- genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], sPLo);
+ Value curC = lvls[tid][lvl]->peekCrdAt(builder, loc, sPLo);
Value isSmaller = CMPI(ult, curC, minCrd);
Value newMin = SELECT(isSmaller, curC, minCrd);
YIELD(newMin);
@@ -2176,8 +2153,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
/* if pLo < pHi */ {
builder.setInsertionPointToStart(&advPLo.getThenRegion().front());
// coord = load[pLo]
- Value coord =
- genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
+ Value coord = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
Value pred = CMPI(eq, coord, info.minCrd);
auto ifEqual = builder.create<scf::IfOp>(loc, idxTp, pred, true);
/* if coord == minCrd */ {
@@ -2209,7 +2185,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
auto newMin =
builder.create<scf::IfOp>(loc, idxTp, lvlNonEmpty, true);
builder.setInsertionPointToStart(&newMin.getThenRegion().front());
- YIELD(genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo));
+ YIELD(lvls[tid][lvl]->peekCrdAt(builder, loc, pLo));
builder.setInsertionPointToStart(&newMin.getElseRegion().front());
YIELD(curMinCrd);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 78bb53e4483f60..272d2bf0e89c2e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -11,6 +11,8 @@
#include <vector>
+#include "SparseTensorLevels.h"
+
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
@@ -241,12 +243,6 @@ class LoopEmitter {
const std::vector<std::vector<Value>> &getPosits() const { return posits; };
const std::vector<std::vector<Value>> &getCoords() const { return coords; };
const std::vector<std::vector<Value>> &getHighs() const { return highs; };
- const std::vector<std::vector<Value>> &getPositionBuffers() const {
- return positionsBuffers;
- };
- const std::vector<std::vector<Value>> &getCoordinateBuffers() const {
- return coordinatesBuffers;
- };
const std::vector<Value> &getValBuffer() const { return valBuffer; };
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
@@ -648,8 +644,9 @@ class LoopEmitter {
std::vector<std::vector<Value>> segHi;
std::vector<std::vector<Value>> highs;
std::vector<std::vector<Value>> lvlSizes;
- std::vector<std::vector<Value>> positionsBuffers; // to_positions
- std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
+ // std::vector<std::vector<Value>> positionsBuffers; // to_positions
+ // std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
+ std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
std::vector<Value> valBuffer; // to_value
//
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
new file mode 100644
index 00000000000000..a9dae17e9de055
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
@@ -0,0 +1,46 @@
+#include "SparseTensorLevels.h"
+#include "CodegenUtils.h"
+
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+std::unique_ptr<SparseTensorLevel>
+sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
+ Level l) {
+ auto stt = getSparseTensorType(t);
+
+ LevelType lt = stt.getLvlType(l);
+ Value lvlSz = stt.hasEncoding()
+ ? builder.create<LvlOp>(loc, t, l).getResult()
+ : builder.create<tensor::DimOp>(loc, t, l).getResult();
+
+ switch (*getLevelFormat(lt)) {
+ case LevelFormat::Dense:
+ return std::make_unique<DenseLevel>(lvlSz);
+ case LevelFormat::Compressed: {
+ Value posBuf = genToPositions(builder, loc, t, l);
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<CompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+ }
+ case LevelFormat::LooseCompressed: {
+ Value posBuf = genToPositions(builder, loc, t, l);
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<LooseCompressedLevel>(lt, lvlSz, posBuf, crdBuf);
+ }
+ case LevelFormat::Singleton: {
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<SingletonLevel>(lt, lvlSz, crdBuf);
+ }
+ case LevelFormat::TwoOutOfFour: {
+ Value crdBuf = genToCoordinates(builder, loc, t, l);
+ return std::make_unique<TwoOutFourLevel>(lt, lvlSz, crdBuf);
+ }
+ }
+ llvm_unreachable("unrecognizable level format");
+}
+
+Value SparseLevel::peekCrdAt(OpBuilder &b, Location l, Value pos) const {
+ return genIndexLoad(b, l, crdBuffer, pos);
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
new file mode 100644
index 00000000000000..c6574295ca7fae
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
@@ -0,0 +1,109 @@
+//===- TensorLevels.h -------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_TENSORLEVEL_H_
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+namespace mlir {
+namespace sparse_tensor {
+
+class SparseTensorLevel {
+ SparseTensorLevel(SparseTensorLevel &&) = delete;
+ SparseTensorLevel(const SparseTensorLevel &) = delete;
+
+public:
+ SparseTensorLevel() : SparseTensorLevel(LevelType::Undef, nullptr){};
+ virtual ~SparseTensorLevel() = default;
+
+ virtual Value peekCrdAt(OpBuilder &b, Location l, Value p) const = 0;
+
+ LevelType getLT() const { return lt; }
+ Value getPos() const { return pos; }
+ Value getCrd() const { return crd; }
+ Value getLoopHi() const { return loopHi; }
+ Value getLoopLo() const { return loopLo; }
+
+protected:
+ SparseTensorLevel(LevelType lt, Value lvlSize)
+ : lt(lt), lvlSize(lvlSize), pos(nullptr), crd(nullptr), loopHi(nullptr),
+ loopLo(nullptr){};
+
+ const LevelType lt;
+ const Value lvlSize;
+
+public: // TODO: make these values private upon feature complete.
+ Value pos;
+ Value crd;
+ Value loopHi;
+ Value loopLo;
+};
+
+/// Helper function to create a TensorLevel object from given `tensor`.
+std::unique_ptr<SparseTensorLevel>
+makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t, Level l);
+
+class DenseLevel : public SparseTensorLevel {
+public:
+ DenseLevel(Value lvlSize) : SparseTensorLevel(LevelType::Dense, lvlSize) {
+ // Dense level, loop upper bound equals to the level size.
+ loopHi = lvlSize;
+ }
+
+ Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
+ return pos;
+ }
+};
+
+class SparseLevel : public SparseTensorLevel {
+public:
+ SparseLevel(LevelType lt, Value lvlSize, Value crdBuffer)
+ : SparseTensorLevel(lt, lvlSize), crdBuffer(crdBuffer) {}
+
+ Value peekCrdAt(OpBuilder &b, Location l, Value pos) const override;
+
+public: // TODO: make these values private upon feature complete.
+ const Value crdBuffer;
+};
+
+class CompressedLevel : public SparseLevel {
+public:
+ CompressedLevel(LevelType lt, Value lvlSize, Value posBuffer, Value crdBuffer)
+ : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+public: // TODO: make these values private upon feature complete.
+ const Value posBuffer;
+};
+
+class LooseCompressedLevel : public SparseLevel {
+public:
+ LooseCompressedLevel(LevelType lt, Value lvlSize, Value posBuffer,
+ Value crdBuffer)
+ : SparseLevel(lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
+
+public: // TODO: make these values private upon feature complete.
+ const Value posBuffer;
+};
+
+class SingletonLevel : public SparseLevel {
+public:
+ SingletonLevel(Lev...
[truncated]
|
6abcccf
to
e31518e
Compare
e31518e
to
c6f97d5
Compare
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.cpp
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevels.h
Outdated
Show resolved
Hide resolved
mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Note that at the current moment, the newly-introduced
SparseTensorLevel
classes are far from complete, we plan to migrate code generation related to accessing sparse tensor levels to these classes in the near future to simplifyLoopEmitter
.