[mlir][sparse] provide an AoS "view" into sparse runtime support lib #87116

Merged · 1 commit · Mar 29, 2024
35 changes: 35 additions & 0 deletions mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h
@@ -143,6 +143,12 @@ class SparseTensorStorageBase {
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATES)
#undef DECL_GETCOORDINATES

/// Gets coordinates-overhead storage buffer for the given level.
#define DECL_GETCOORDINATESBUFFER(INAME, C) \
virtual void getCoordinatesBuffer(std::vector<C> **, uint64_t);
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(DECL_GETCOORDINATESBUFFER)
#undef DECL_GETCOORDINATESBUFFER

/// Gets primary storage.
#define DECL_GETVALUES(VNAME, V) virtual void getValues(std::vector<V> **);
MLIR_SPARSETENSOR_FOREVERY_V(DECL_GETVALUES)
@@ -251,6 +257,31 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
assert(lvl < getLvlRank());
*out = &coordinates[lvl];
}
void getCoordinatesBuffer(std::vector<C> **out, uint64_t lvl) final {
assert(out && "Received nullptr for out parameter");
assert(lvl < getLvlRank());
// Note that the sparse tensor support library always stores COO in SoA
// format, even when AoS is requested. This is never an issue, since all
// actual code/library generation requests "views" into the coordinate
// storage for the individual levels, which is trivially provided for
// both AoS and SoA (as well as all the other storage formats). The only
// exception is when the buffer version of coordinate storage is requested
// (currently only for printing). In that case, we do the following
// potentially expensive transformation to provide that view. If this
// operation becomes more common beyond debugging, we should consider
// implementing proper AoS in the support library as well.
uint64_t lvlRank = getLvlRank();
uint64_t nnz = values.size();
crdBuffer.clear();
crdBuffer.reserve(nnz * (lvlRank - lvl));
for (uint64_t i = 0; i < nnz; i++) {
for (uint64_t l = lvl; l < lvlRank; l++) {
assert(i < coordinates[l].size());
crdBuffer.push_back(coordinates[l][i]);
}
}
*out = &crdBuffer;
}
void getValues(std::vector<V> **out) final {
assert(out && "Received nullptr for out parameter");
*out = &values;
@@ -529,10 +560,14 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
return -1u;
}

// Sparse tensor storage components.
std::vector<std::vector<P>> positions;
std::vector<std::vector<C>> coordinates;
std::vector<V> values;

// Auxiliary data structures.
std::vector<uint64_t> lvlCursor;
std::vector<C> crdBuffer; // just for AoS view
};

//===----------------------------------------------------------------------===//
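The comment inside getCoordinatesBuffer above describes how the library, which keeps COO coordinates in SoA form, materializes an AoS view on demand by interleaving the per-level coordinate vectors. Below is a minimal standalone C++ sketch of that flattening; the toy data and names are illustrative, not the library class itself.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // SoA storage: one coordinate vector per level (3 nonzeros, 2 levels).
  std::vector<std::vector<uint64_t>> coordinates = {{0, 0, 3}, {0, 2, 5}};
  const uint64_t lvl = 0, lvlRank = 2, nnz = 3;

  // AoS view: the coordinates of each nonzero are interleaved per entry,
  // exactly as in the loop of getCoordinatesBuffer above.
  std::vector<uint64_t> crdBuffer;
  crdBuffer.reserve(nnz * (lvlRank - lvl));
  for (uint64_t i = 0; i < nnz; i++)
    for (uint64_t l = lvl; l < lvlRank; l++)
      crdBuffer.push_back(coordinates[l][i]);

  // Prints: 0 0 0 2 3 5
  for (uint64_t c : crdBuffer)
    std::cout << c << " ";
  std::cout << "\n";
  return 0;
}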
8 changes: 8 additions & 0 deletions mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -77,6 +77,14 @@ MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSEPOSITIONS)
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
#undef DECL_SPARSECOORDINATES

/// Tensor-storage method to obtain direct access to the coordinates array
/// buffer for the given level (provides an AoS view into the library).
#define DECL_SPARSECOORDINATES(CNAME, C) \
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_sparseCoordinatesBuffer##CNAME( \
StridedMemRefType<C, 1> *out, void *tensor, index_type lvl);
MLIR_SPARSETENSOR_FOREVERY_O(DECL_SPARSECOORDINATES)
#undef DECL_SPARSECOORDINATES

/// Tensor-storage method to insert elements in lexicographical
/// level-coordinate order.
#define DECL_LEXINSERT(VNAME, V) \
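The declaration above uses the library's X-macro scheme: MLIR_SPARSETENSOR_FOREVERY_O expands the macro once per overhead type, stamping out one _mlir_ciface_sparseCoordinatesBuffer entry point per suffix. Here is a self-contained sketch of that pattern; FOREVERY_O_DEMO and demoCoordinatesBuffer* are illustrative names, not the real macros.

#include <cstdint>
#include <iostream>
#include <vector>

// Demo X-macro: applies DO once per (suffix, type) pair.
#define FOREVERY_O_DEMO(DO)                                                    \
  DO(64, uint64_t)                                                             \
  DO(32, uint32_t)

// Stamps out one function per overhead type.
#define DECL_DEMO(CNAME, C)                                                    \
  std::vector<C> demoCoordinatesBuffer##CNAME() { return {0, 0, 0, 2, 3, 5}; }
FOREVERY_O_DEMO(DECL_DEMO)
#undef DECL_DEMO

int main() {
  auto buf64 = demoCoordinatesBuffer64(); // instantiated for uint64_t
  auto buf32 = demoCoordinatesBuffer32(); // instantiated for uint32_t
  std::cout << buf64.size() << " " << buf32.size() << "\n"; // 6 6
  return 0;
}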
@@ -275,7 +275,7 @@ static Value genPositionsCall(OpBuilder &builder, Location loc,
.getResult(0);
}

/// Generates a call to obtain the coordindates array.
/// Generates a call to obtain the coordinates array.
static Value genCoordinatesCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value ptr, Level l) {
Type crdTp = stt.getCrdType();
@@ -287,6 +287,20 @@ static Value genCoordinatesCall(OpBuilder &builder, Location loc,
.getResult(0);
}

/// Generates a call to obtain the coordinates array (AoS view).
static Value genCoordinatesBufferCall(OpBuilder &builder, Location loc,
SparseTensorType stt, Value ptr,
Level l) {
Type crdTp = stt.getCrdType();
auto resTp = MemRefType::get({ShapedType::kDynamic}, crdTp);
Value lvl = constantIndex(builder, loc, l);
SmallString<25> name{"sparseCoordinatesBuffer",
overheadTypeFunctionSuffix(crdTp)};
return createFuncCall(builder, loc, name, resTp, {ptr, lvl},
EmitCInterface::On)
.getResult(0);
}

//===----------------------------------------------------------------------===//
// Conversion rules.
//===----------------------------------------------------------------------===//
@@ -518,13 +532,35 @@ class SparseTensorToCoordinatesConverter
LogicalResult
matchAndRewrite(ToCoordinatesOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getTensor());
auto crds = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
op.getLevel());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
if (op.getType() != crds.getType())
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
rewriter.replaceOp(op, crds);
return success();
}
};

/// Sparse conversion rule for coordinate accesses (AoS style).
class SparseToCoordinatesBufferConverter
: public OpConversionPattern<ToCoordinatesBufferOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ToCoordinatesBufferOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getTensor());
auto crds = genCoordinatesCall(rewriter, op.getLoc(), stt,
adaptor.getTensor(), op.getLevel());
auto crds = genCoordinatesBufferCall(
rewriter, loc, stt, adaptor.getTensor(), stt.getAoSCOOStart());
// Cast the MemRef type to the type expected by the users, though these
// two types should be compatible at runtime.
if (op.getType() != crds.getType())
crds = rewriter.create<memref::CastOp>(op.getLoc(), op.getType(), crds);
crds = rewriter.create<memref::CastOp>(loc, op.getType(), crds);
rewriter.replaceOp(op, crds);
return success();
}
@@ -878,10 +914,10 @@ void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
SparseTensorAllocConverter, SparseTensorEmptyConverter,
SparseTensorDeallocConverter, SparseTensorReorderCOOConverter,
SparseTensorToPositionsConverter, SparseTensorToCoordinatesConverter,
SparseTensorToValuesConverter, SparseNumberOfEntriesConverter,
SparseTensorLoadConverter, SparseTensorInsertConverter,
SparseTensorExpandConverter, SparseTensorCompressConverter,
SparseTensorAssembleConverter, SparseTensorDisassembleConverter,
SparseHasRuntimeLibraryConverter>(typeConverter,
patterns.getContext());
SparseToCoordinatesBufferConverter, SparseTensorToValuesConverter,
SparseNumberOfEntriesConverter, SparseTensorLoadConverter,
SparseTensorInsertConverter, SparseTensorExpandConverter,
SparseTensorCompressConverter, SparseTensorAssembleConverter,
SparseTensorDisassembleConverter, SparseHasRuntimeLibraryConverter>(
typeConverter, patterns.getContext());
}
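genCoordinatesBufferCall above selects the runtime entry point by appending an overhead-type suffix (via overheadTypeFunctionSuffix, shown in the diff) to the fixed prefix "sparseCoordinatesBuffer". The sketch below illustrates that naming scheme only; the enum and the suffix mapping are assumptions for illustration, mirroring the sibling sparseCoordinates calls, not the real implementation.

#include <iostream>
#include <string>

enum class OverheadType { kIndex, kU64, kU32 };

// Assumed suffix mapping: "0" for index, bit width otherwise.
static std::string overheadSuffix(OverheadType t) {
  switch (t) {
  case OverheadType::kIndex:
    return "0";
  case OverheadType::kU64:
    return "64";
  case OverheadType::kU32:
    return "32";
  }
  return "";
}

int main() {
  std::string name =
      "sparseCoordinatesBuffer" + overheadSuffix(OverheadType::kU64);
  std::cout << name << "\n"; // sparseCoordinatesBuffer64
  return 0;
}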
@@ -648,7 +648,9 @@ struct PrintRewriter : public OpRewritePattern<PrintOp> {
loc, lvl, vector::PrintPunctuation::NoPunctuation);
rewriter.create<vector::PrintOp>(loc, rewriter.getStringAttr("] : "));
Value crd = nullptr;
// TODO: eliminates ToCoordinateBufferOp!
// For COO AoS storage, we want to print a single, linear view of
// the full coordinate storage at this level. For any other storage,
// we show the coordinate storage for every individual level.
if (stt.getAoSCOOStart() == l)
crd = rewriter.create<ToCoordinatesBufferOp>(loc, tensor);
else
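The comment above explains why the print rewriter emits ToCoordinatesBufferOp only at the AoS COO start level: from that level onward the coordinate storage prints as one interleaved buffer rather than one line per level. A standalone sketch of the resulting output difference, using the same five nonzeros as the integration test further below:

#include <cstdint>
#include <iostream>
#include <vector>

static void printVec(const char *label, const std::vector<uint64_t> &v) {
  std::cout << label << " : ( ";
  for (uint64_t x : v)
    std::cout << x << ", ";
  std::cout << "\n";
}

int main() {
  std::vector<uint64_t> crd0 = {0, 0, 3, 3, 3}; // row coordinates
  std::vector<uint64_t> crd1 = {0, 2, 2, 3, 5}; // column coordinates

  // SoA print: one line per level.
  printVec("crd[0]", crd0);
  printVec("crd[1]", crd1);

  // AoS print: a single interleaved buffer at the COO start level.
  std::vector<uint64_t> aos;
  for (size_t i = 0; i < crd0.size(); i++) {
    aos.push_back(crd0[i]);
    aos.push_back(crd1[i]);
  }
  printVec("crd[0]", aos); // ( 0, 0, 0, 2, 3, 2, 3, 3, 3, 5,
  return 0;
}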
8 changes: 8 additions & 0 deletions mlir/lib/ExecutionEngine/SparseTensor/Storage.cpp
@@ -68,6 +68,14 @@ MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETPOSITIONS)
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATES)
#undef IMPL_GETCOORDINATES

#define IMPL_GETCOORDINATESBUFFER(CNAME, C) \
void SparseTensorStorageBase::getCoordinatesBuffer(std::vector<C> **, \
uint64_t) { \
FATAL_PIV("getCoordinatesBuffer" #CNAME); \
}
MLIR_SPARSETENSOR_FOREVERY_FIXED_O(IMPL_GETCOORDINATESBUFFER)
#undef IMPL_GETCOORDINATESBUFFER

#define IMPL_GETVALUES(VNAME, V) \
void SparseTensorStorageBase::getValues(std::vector<V> **) { \
FATAL_PIV("getValues" #VNAME); \
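The IMPL_GETCOORDINATESBUFFER block above gives the base class a default implementation that aborts via FATAL_PIV, so the query only succeeds when the concrete templated storage overrides it for the coordinate type it actually holds. A simplified standalone sketch of that virtual-dispatch pattern; the class names and abort message are illustrative, not the library's own.

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

struct StorageBase {
  virtual ~StorageBase() = default;
  // Default: abort, analogous to FATAL_PIV in the base class above.
  virtual void getCoordinatesBuffer(std::vector<uint64_t> **, uint64_t) {
    std::fprintf(stderr, "unsupported coordinate type\n");
    std::exit(1);
  }
};

struct Storage64 final : StorageBase {
  void getCoordinatesBuffer(std::vector<uint64_t> **out, uint64_t) override {
    *out = &crdBuffer;
  }
  std::vector<uint64_t> crdBuffer = {0, 0, 0, 2, 3, 5};
};

int main() {
  Storage64 st;
  StorageBase &base = st;
  std::vector<uint64_t> *buf = nullptr;
  base.getCoordinatesBuffer(&buf, 0); // dispatches to the override
  std::printf("%zu\n", buf->size());  // 6
  return 0;
}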
7 changes: 7 additions & 0 deletions mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -311,6 +311,7 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
assert(v); \
aliasIntoMemref(v->size(), v->data(), *ref); \
}

#define IMPL_SPARSEPOSITIONS(PNAME, P) \
IMPL_GETOVERHEAD(sparsePositions##PNAME, P, getPositions)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
@@ -320,6 +321,12 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEPOSITIONS)
IMPL_GETOVERHEAD(sparseCoordinates##CNAME, C, getCoordinates)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATES)
#undef IMPL_SPARSECOORDINATES

#define IMPL_SPARSECOORDINATESBUFFER(CNAME, C) \
IMPL_GETOVERHEAD(sparseCoordinatesBuffer##CNAME, C, getCoordinatesBuffer)
MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSECOORDINATESBUFFER)
#undef IMPL_SPARSECOORDINATESBUFFER

#undef IMPL_GETOVERHEAD

#define IMPL_LEXINSERT(VNAME, V) \
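The IMPL_GETOVERHEAD wrapper (shown partially above) calls the virtual getter to obtain the std::vector and then aliases that vector into a rank-1 memref descriptor without copying, as the aliasIntoMemref call in the diff indicates. Below is a simplified, self-contained sketch of that aliasing step; SimpleMemRef1D and aliasVectorIntoMemref are stand-ins for the real StridedMemRefType and aliasIntoMemref, not their actual declarations.

#include <cstdint>
#include <iostream>
#include <vector>

template <typename T>
struct SimpleMemRef1D { // stand-in for a rank-1 strided memref descriptor
  T *basePtr;
  T *data;
  int64_t offset;
  int64_t sizes[1];
  int64_t strides[1];
};

template <typename T>
void aliasVectorIntoMemref(std::vector<T> &vec, SimpleMemRef1D<T> &ref) {
  // No copy: the descriptor simply points at the vector's storage.
  ref.basePtr = ref.data = vec.data();
  ref.offset = 0;
  ref.sizes[0] = static_cast<int64_t>(vec.size());
  ref.strides[0] = 1;
}

int main() {
  std::vector<uint64_t> crdBuffer = {0, 0, 0, 2, 3, 5}; // AoS view contents
  SimpleMemRef1D<uint64_t> ref;
  aliasVectorIntoMemref(crdBuffer, ref);
  std::cout << "size = " << ref.sizes[0] << ", first = " << ref.data[0] << "\n";
  return 0;
}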
57 changes: 45 additions & 12 deletions mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_print.mlir
@@ -120,6 +120,14 @@
)
}>

#COOAoS = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton)
}>

#COOSoA = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
}>

module {

//
@@ -161,6 +169,8 @@ module {
%h = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSCC>
%i = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSR0>
%j = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #BSC0>
%AoS = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOAoS>
%SoA = sparse_tensor.convert %x : tensor<4x8xi32> to tensor<4x8xi32, #COOSoA>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
@@ -274,19 +284,42 @@ module {
// CHECK-NEXT: ----
sparse_tensor.print %j : tensor<4x8xi32, #BSC0>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: dim = ( 4, 8 )
// CHECK-NEXT: lvl = ( 4, 8 )
// CHECK-NEXT: pos[0] : ( 0, 5,
// CHECK-NEXT: crd[0] : ( 0, 0, 0, 2, 3, 2, 3, 3, 3, 5,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %AoS : tensor<4x8xi32, #COOAoS>

// CHECK-NEXT: ---- Sparse Tensor ----
// CHECK-NEXT: nse = 5
// CHECK-NEXT: dim = ( 4, 8 )
// CHECK-NEXT: lvl = ( 4, 8 )
// CHECK-NEXT: pos[0] : ( 0, 5,
// CHECK-NEXT: crd[0] : ( 0, 0, 3, 3, 3,
// CHECK-NEXT: crd[1] : ( 0, 2, 2, 3, 5,
// CHECK-NEXT: values : ( 1, 2, 3, 4, 5,
// CHECK-NEXT: ----
sparse_tensor.print %SoA : tensor<4x8xi32, #COOSoA>

// Release the resources.
bufferization.dealloc_tensor %XO : tensor<4x8xi32, #AllDense>
bufferization.dealloc_tensor %XT : tensor<4x8xi32, #AllDenseT>
bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
bufferization.dealloc_tensor %i : tensor<4x8xi32, #BSR0>
bufferization.dealloc_tensor %j : tensor<4x8xi32, #BSC0>
bufferization.dealloc_tensor %XO : tensor<4x8xi32, #AllDense>
bufferization.dealloc_tensor %XT : tensor<4x8xi32, #AllDenseT>
bufferization.dealloc_tensor %a : tensor<4x8xi32, #CSR>
bufferization.dealloc_tensor %b : tensor<4x8xi32, #DCSR>
bufferization.dealloc_tensor %c : tensor<4x8xi32, #CSC>
bufferization.dealloc_tensor %d : tensor<4x8xi32, #DCSC>
bufferization.dealloc_tensor %e : tensor<4x8xi32, #BSR>
bufferization.dealloc_tensor %f : tensor<4x8xi32, #BSRC>
bufferization.dealloc_tensor %g : tensor<4x8xi32, #BSC>
bufferization.dealloc_tensor %h : tensor<4x8xi32, #BSCC>
bufferization.dealloc_tensor %i : tensor<4x8xi32, #BSR0>
bufferization.dealloc_tensor %j : tensor<4x8xi32, #BSC0>
bufferization.dealloc_tensor %AoS : tensor<4x8xi32, #COOAoS>
bufferization.dealloc_tensor %SoA : tensor<4x8xi32, #COOSoA>

return
}