Skip to content

Commit 845561e

Browse files
committed
[mlir][sparse] Factoring magic numbers into a header
Addresses https://bugs.llvm.org/show_bug.cgi?id=52303 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D112962
1 parent f57d0e2 commit 845561e

File tree

4 files changed

+260
-175
lines changed

4 files changed

+260
-175
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===- SparseTensorUtils.h - Enums shared with the runtime ------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines several enums shared between
10+
// Transforms/SparseTensorConversion.cpp and ExecutionEngine/SparseUtils.cpp
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#ifndef MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
15+
#define MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_
16+
17+
#include <cinttypes>
18+
19+
extern "C" {
20+
21+
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
22+
enum class OverheadType : uint32_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };
23+
24+
/// Encoding of the elemental type, for "overloading" @newSparseTensor.
25+
enum class PrimaryType : uint32_t {
26+
kF64 = 1,
27+
kF32 = 2,
28+
kI64 = 3,
29+
kI32 = 4,
30+
kI16 = 5,
31+
kI8 = 6
32+
};
33+
34+
/// The actions performed by @newSparseTensor.
35+
enum class Action : uint32_t {
36+
kEmpty = 0,
37+
kFromFile = 1,
38+
kFromCOO = 2,
39+
kEmptyCOO = 3,
40+
kToCOO = 4,
41+
kToIterator = 5
42+
};
43+
44+
/// This enum mimics `SparseTensorEncodingAttr::DimLevelType` for
45+
/// breaking dependency cycles. `SparseTensorEncodingAttr::DimLevelType`
46+
/// is the source of truth and this enum should be kept consistent with it.
47+
enum class DimLevelType : uint8_t {
48+
kDense = 0,
49+
kCompressed = 1,
50+
kSingleton = 2
51+
};
52+
53+
} // extern "C"
54+
55+
#endif // MLIR_EXECUTIONENGINE_SPARSETENSORUTILS_H_

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 91 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -22,76 +22,18 @@
2222
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
2323
#include "mlir/Dialect/StandardOps/IR/Ops.h"
2424
#include "mlir/Dialect/Tensor/IR/Tensor.h"
25+
#include "mlir/ExecutionEngine/SparseTensorUtils.h"
2526
#include "mlir/Transforms/DialectConversion.h"
2627

2728
using namespace mlir;
2829
using namespace mlir::sparse_tensor;
2930

3031
namespace {
3132

32-
/// New tensor storage action. Keep these values consistent with
33-
/// the sparse runtime support library.
34-
enum Action : uint32_t {
35-
kEmpty = 0,
36-
kFromFile = 1,
37-
kFromCOO = 2,
38-
kEmptyCOO = 3,
39-
kToCOO = 4,
40-
kToIter = 5
41-
};
42-
4333
//===----------------------------------------------------------------------===//
4434
// Helper methods.
4535
//===----------------------------------------------------------------------===//
4636

47-
/// Returns internal type encoding for primary storage. Keep these
48-
/// values consistent with the sparse runtime support library.
49-
static uint32_t getPrimaryTypeEncoding(Type tp) {
50-
if (tp.isF64())
51-
return 1;
52-
if (tp.isF32())
53-
return 2;
54-
if (tp.isInteger(64))
55-
return 3;
56-
if (tp.isInteger(32))
57-
return 4;
58-
if (tp.isInteger(16))
59-
return 5;
60-
if (tp.isInteger(8))
61-
return 6;
62-
return 0;
63-
}
64-
65-
/// Returns internal type encoding for overhead storage. Keep these
66-
/// values consistent with the sparse runtime support library.
67-
static uint32_t getOverheadTypeEncoding(unsigned width) {
68-
switch (width) {
69-
default:
70-
return 1;
71-
case 32:
72-
return 2;
73-
case 16:
74-
return 3;
75-
case 8:
76-
return 4;
77-
}
78-
}
79-
80-
/// Returns internal dimension level type encoding. Keep these
81-
/// values consistent with the sparse runtime support library.
82-
static uint32_t
83-
getDimLevelTypeEncoding(SparseTensorEncodingAttr::DimLevelType dlt) {
84-
switch (dlt) {
85-
case SparseTensorEncodingAttr::DimLevelType::Dense:
86-
return 0;
87-
case SparseTensorEncodingAttr::DimLevelType::Compressed:
88-
return 1;
89-
case SparseTensorEncodingAttr::DimLevelType::Singleton:
90-
return 2;
91-
}
92-
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
93-
}
94-
9537
/// Generates a constant zero of the given type.
9638
inline static Value constantZero(ConversionPatternRewriter &rewriter,
9739
Location loc, Type t) {
@@ -116,6 +58,75 @@ inline static Value constantI8(ConversionPatternRewriter &rewriter,
11658
return rewriter.create<arith::ConstantIntOp>(loc, i, 8);
11759
}
11860

61+
/// Generates a constant of the given `Action`.
62+
static Value constantAction(ConversionPatternRewriter &rewriter, Location loc,
63+
Action action) {
64+
return constantI32(rewriter, loc, static_cast<uint32_t>(action));
65+
}
66+
67+
/// Generates a constant of the internal type encoding for overhead storage.
68+
static Value constantOverheadTypeEncoding(ConversionPatternRewriter &rewriter,
69+
Location loc, unsigned width) {
70+
OverheadType sec;
71+
switch (width) {
72+
default:
73+
sec = OverheadType::kU64;
74+
break;
75+
case 32:
76+
sec = OverheadType::kU32;
77+
break;
78+
case 16:
79+
sec = OverheadType::kU16;
80+
break;
81+
case 8:
82+
sec = OverheadType::kU8;
83+
break;
84+
}
85+
return constantI32(rewriter, loc, static_cast<uint32_t>(sec));
86+
}
87+
88+
/// Generates a constant of the internal type encoding for primary storage.
89+
static Value constantPrimaryTypeEncoding(ConversionPatternRewriter &rewriter,
90+
Location loc, Type tp) {
91+
PrimaryType primary;
92+
if (tp.isF64())
93+
primary = PrimaryType::kF64;
94+
else if (tp.isF32())
95+
primary = PrimaryType::kF32;
96+
else if (tp.isInteger(64))
97+
primary = PrimaryType::kI64;
98+
else if (tp.isInteger(32))
99+
primary = PrimaryType::kI32;
100+
else if (tp.isInteger(16))
101+
primary = PrimaryType::kI16;
102+
else if (tp.isInteger(8))
103+
primary = PrimaryType::kI8;
104+
else
105+
llvm_unreachable("Unknown element type");
106+
return constantI32(rewriter, loc, static_cast<uint32_t>(primary));
107+
}
108+
109+
/// Generates a constant of the internal dimension level type encoding.
110+
static Value
111+
constantDimLevelTypeEncoding(ConversionPatternRewriter &rewriter, Location loc,
112+
SparseTensorEncodingAttr::DimLevelType dlt) {
113+
DimLevelType dlt2;
114+
switch (dlt) {
115+
case SparseTensorEncodingAttr::DimLevelType::Dense:
116+
dlt2 = DimLevelType::kDense;
117+
break;
118+
case SparseTensorEncodingAttr::DimLevelType::Compressed:
119+
dlt2 = DimLevelType::kCompressed;
120+
break;
121+
case SparseTensorEncodingAttr::DimLevelType::Singleton:
122+
dlt2 = DimLevelType::kSingleton;
123+
break;
124+
default:
125+
llvm_unreachable("Unknown SparseTensorEncodingAttr::DimLevelType");
126+
}
127+
return constantI8(rewriter, loc, static_cast<uint8_t>(dlt2));
128+
}
129+
119130
/// Returns a function reference (first hit also inserts into module). Sets
120131
/// the "_emit_c_interface" on the function declaration when requested,
121132
/// so that LLVM lowering generates a wrapper function that takes care
@@ -238,15 +249,15 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
238249
/// computation.
239250
static void newParams(ConversionPatternRewriter &rewriter,
240251
SmallVector<Value, 8> &params, Operation *op,
241-
SparseTensorEncodingAttr &enc, uint32_t action,
252+
SparseTensorEncodingAttr &enc, Action action,
242253
ValueRange szs, Value ptr = Value()) {
243254
Location loc = op->getLoc();
244255
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType();
245256
unsigned sz = dlt.size();
246257
// Sparsity annotations.
247258
SmallVector<Value, 4> attrs;
248259
for (unsigned i = 0; i < sz; i++)
249-
attrs.push_back(constantI8(rewriter, loc, getDimLevelTypeEncoding(dlt[i])));
260+
attrs.push_back(constantDimLevelTypeEncoding(rewriter, loc, dlt[i]));
250261
params.push_back(genBuffer(rewriter, loc, attrs));
251262
// Dimension sizes array of the enveloping tensor. Useful for either
252263
// verification of external data, or for construction of internal data.
@@ -268,18 +279,17 @@ static void newParams(ConversionPatternRewriter &rewriter,
268279
params.push_back(genBuffer(rewriter, loc, rev));
269280
// Secondary and primary types encoding.
270281
ShapedType resType = op->getResult(0).getType().cast<ShapedType>();
271-
uint32_t secPtr = getOverheadTypeEncoding(enc.getPointerBitWidth());
272-
uint32_t secInd = getOverheadTypeEncoding(enc.getIndexBitWidth());
273-
uint32_t primary = getPrimaryTypeEncoding(resType.getElementType());
274-
assert(primary);
275-
params.push_back(constantI32(rewriter, loc, secPtr));
276-
params.push_back(constantI32(rewriter, loc, secInd));
277-
params.push_back(constantI32(rewriter, loc, primary));
282+
params.push_back(
283+
constantOverheadTypeEncoding(rewriter, loc, enc.getPointerBitWidth()));
284+
params.push_back(
285+
constantOverheadTypeEncoding(rewriter, loc, enc.getIndexBitWidth()));
286+
params.push_back(
287+
constantPrimaryTypeEncoding(rewriter, loc, resType.getElementType()));
278288
// User action and pointer.
279289
Type pTp = LLVM::LLVMPointerType::get(rewriter.getI8Type());
280290
if (!ptr)
281291
ptr = rewriter.create<LLVM::NullOp>(loc, pTp);
282-
params.push_back(constantI32(rewriter, loc, action));
292+
params.push_back(constantAction(rewriter, loc, action));
283293
params.push_back(ptr);
284294
}
285295

@@ -530,7 +540,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
530540
SmallVector<Value, 8> params;
531541
sizesFromType(rewriter, sizes, op.getLoc(), resType.cast<ShapedType>());
532542
Value ptr = adaptor.getOperands()[0];
533-
newParams(rewriter, params, op, enc, kFromFile, sizes, ptr);
543+
newParams(rewriter, params, op, enc, Action::kFromFile, sizes, ptr);
534544
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
535545
return success();
536546
}
@@ -549,7 +559,7 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
549559
// Generate the call to construct empty tensor. The sizes are
550560
// explicitly defined by the arguments to the init operator.
551561
SmallVector<Value, 8> params;
552-
newParams(rewriter, params, op, enc, kEmpty, adaptor.getOperands());
562+
newParams(rewriter, params, op, enc, Action::kEmpty, adaptor.getOperands());
553563
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
554564
return success();
555565
}
@@ -588,13 +598,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
588598
auto enc = SparseTensorEncodingAttr::get(
589599
op->getContext(), encDst.getDimLevelType(), encDst.getDimOrdering(),
590600
encSrc.getPointerBitWidth(), encSrc.getIndexBitWidth());
591-
newParams(rewriter, params, op, enc, kToCOO, sizes, src);
601+
newParams(rewriter, params, op, enc, Action::kToCOO, sizes, src);
592602
Value coo = genNewCall(rewriter, op, params);
593-
params[3] = constantI32(
594-
rewriter, loc, getOverheadTypeEncoding(encDst.getPointerBitWidth()));
595-
params[4] = constantI32(
596-
rewriter, loc, getOverheadTypeEncoding(encDst.getIndexBitWidth()));
597-
params[6] = constantI32(rewriter, loc, kFromCOO);
603+
params[3] = constantOverheadTypeEncoding(rewriter, loc,
604+
encDst.getPointerBitWidth());
605+
params[4] = constantOverheadTypeEncoding(rewriter, loc,
606+
encDst.getIndexBitWidth());
607+
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
598608
params[7] = coo;
599609
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
600610
return success();
@@ -613,7 +623,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
613623
Type elemTp = dstTensorTp.getElementType();
614624
// Fabricate a no-permutation encoding for newParams().
615625
// The pointer/index types must be those of `src`.
616-
// The dimLevelTypes aren't actually used by kToIter.
626+
// The dimLevelTypes aren't actually used by Action::kToIterator.
617627
encDst = SparseTensorEncodingAttr::get(
618628
op->getContext(),
619629
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
@@ -622,7 +632,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
622632
SmallVector<Value, 4> sizes;
623633
SmallVector<Value, 8> params;
624634
sizesFromPtr(rewriter, sizes, op, encSrc, srcTensorTp, src);
625-
newParams(rewriter, params, op, encDst, kToIter, sizes, src);
635+
newParams(rewriter, params, op, encDst, Action::kToIterator, sizes, src);
626636
Value iter = genNewCall(rewriter, op, params);
627637
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
628638
Value elemPtr = genAllocaScalar(rewriter, loc, elemTp);
@@ -677,7 +687,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
677687
SmallVector<Value, 4> sizes;
678688
SmallVector<Value, 8> params;
679689
sizesFromSrc(rewriter, sizes, loc, src);
680-
newParams(rewriter, params, op, encDst, kEmptyCOO, sizes);
690+
newParams(rewriter, params, op, encDst, Action::kEmptyCOO, sizes);
681691
Value ptr = genNewCall(rewriter, op, params);
682692
Value ind = genAlloca(rewriter, loc, rank, rewriter.getIndexType());
683693
Value perm = params[2];
@@ -718,7 +728,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
718728
return {};
719729
});
720730
// Final call to construct sparse tensor storage.
721-
params[6] = constantI32(rewriter, loc, kFromCOO);
731+
params[6] = constantAction(rewriter, loc, Action::kFromCOO);
722732
params[7] = ptr;
723733
rewriter.replaceOp(op, genNewCall(rewriter, op, params));
724734
return success();

0 commit comments

Comments
 (0)