Skip to content

[mlir][sparse] rename files and unify APIs #88162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
Utils/IterationGraphSorter.cpp
Utils/LoopEmitter.cpp
Utils/SparseTensorDescriptor.cpp
Utils/SparseTensorLevel.cpp
Utils/SparseTensorIterator.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

#include <vector>

#include "SparseTensorLevel.h"
#include "SparseTensorIterator.h"

#include "mlir/Dialect/SparseTensor/IR/Enums.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
//===- SparseTensorLevel.cpp - Tensor management class -------------------===//
//===- SparseTensorIterator.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "SparseTensorLevel.h"
#include "SparseTensorIterator.h"
#include "CodegenUtils.h"

#include "mlir/Dialect/MemRef/IR/MemRef.h"
Expand Down Expand Up @@ -46,21 +46,41 @@ using ValueTuple = std::tuple<Value, Value, Value>;

namespace {

template <bool hasPosBuffer>
class SparseLevel : public SparseTensorLevel {
// It is either an array of size 2 or size 1 depending on whether the sparse
// level requires a position array.
using BufferT = std::conditional_t<hasPosBuffer, std::array<Value, 2>,
std::array<Value, 1>>;

public:
SparseLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
: SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
BufferT buffers)
: SparseTensorLevel(tid, lvl, lt, lvlSize), buffers(buffers) {}

ValueRange getLvlBuffers() const override { return buffers; }

Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value iv) const override {
SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(iv);
return genIndexLoad(b, l, crdBuffer, memCrd);
return genIndexLoad(b, l, getCrdBuf(), memCrd);
}

protected:
const Value crdBuffer;
template <typename T = void, typename = std::enable_if_t<hasPosBuffer, T>>
Value getPosBuf() const {
return buffers[0];
}

Value getCrdBuf() const {
if constexpr (hasPosBuffer)
return buffers[1];
else
return buffers[0];
}

const BufferT buffers;
};

class DenseLevel : public SparseTensorLevel {
Expand All @@ -72,6 +92,8 @@ class DenseLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}

ValueRange getLvlBuffers() const override { return {}; }

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
Value posLo = MULI(p, lvlSize);
Expand All @@ -88,6 +110,8 @@ class BatchLevel : public SparseTensorLevel {
llvm_unreachable("locate random-accessible level instead");
}

ValueRange getLvlBuffers() const override { return {}; }

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
Expand All @@ -96,11 +120,11 @@ class BatchLevel : public SparseTensorLevel {
}
};

class CompressedLevel : public SparseLevel {
class CompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
CompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
Expand All @@ -109,21 +133,18 @@ class CompressedLevel : public SparseLevel {

SmallVector<Value> memCrd(batchPrefix);
memCrd.push_back(p);
Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}

private:
const Value posBuffer;
};

class LooseCompressedLevel : public SparseLevel {
class LooseCompressedLevel : public SparseLevel</*hasPosBuf=*/true> {
public:
LooseCompressedLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value posBuffer, Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
: SparseLevel(tid, lvl, lt, lvlSize, {posBuffer, crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
Expand All @@ -133,21 +154,18 @@ class LooseCompressedLevel : public SparseLevel {

p = MULI(p, C_IDX(2));
memCrd.push_back(p);
Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
Value pLo = genIndexLoad(b, l, getPosBuf(), memCrd);
memCrd.back() = ADDI(p, C_IDX(1));
Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
Value pHi = genIndexLoad(b, l, getPosBuf(), memCrd);
return {pLo, pHi};
}

private:
const Value posBuffer;
};

class SingletonLevel : public SparseLevel {
class SingletonLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
SingletonLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value segHi) const override {
Expand All @@ -159,11 +177,11 @@ class SingletonLevel : public SparseLevel {
}
};

class NOutOfMLevel : public SparseLevel {
class NOutOfMLevel : public SparseLevel</*hasPosBuf=*/false> {
public:
NOutOfMLevel(unsigned tid, Level lvl, LevelType lt, Value lvlSize,
Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer) {}
: SparseLevel(tid, lvl, lt, lvlSize, {crdBuffer}) {}

ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
Value p, Value max) const override {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
//===- SparseTensorLevel.h --------------------------------------*- C++ -*-===//
//===- SparseTensorIterator.h ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
Expand Down Expand Up @@ -55,6 +55,7 @@ class SparseTensorLevel {
Level getLevel() const { return lvl; }
LevelType getLT() const { return lt; }
Value getSize() const { return lvlSize; }
virtual ValueRange getLvlBuffers() const = 0;

//
// Level properties
Expand Down Expand Up @@ -321,4 +322,4 @@ std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
} // namespace sparse_tensor
} // namespace mlir

#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORLEVEL_H_
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_SPARSETENSORITERATOR_H_