22
22
#include " mlir/Dialect/SparseTensor/Transforms/Passes.h"
23
23
#include " mlir/Dialect/StandardOps/IR/Ops.h"
24
24
#include " mlir/Dialect/Tensor/IR/Tensor.h"
25
+ #include " mlir/ExecutionEngine/SparseTensorUtils.h"
25
26
#include " mlir/Transforms/DialectConversion.h"
26
27
27
28
using namespace mlir ;
28
29
using namespace mlir ::sparse_tensor;
29
30
30
31
namespace {
31
32
32
- // / New tensor storage action. Keep these values consistent with
33
- // / the sparse runtime support library.
34
- enum Action : uint32_t {
35
- kEmpty = 0 ,
36
- kFromFile = 1 ,
37
- kFromCOO = 2 ,
38
- kEmptyCOO = 3 ,
39
- kToCOO = 4 ,
40
- kToIter = 5
41
- };
42
-
43
33
// ===----------------------------------------------------------------------===//
44
34
// Helper methods.
45
35
// ===----------------------------------------------------------------------===//
46
36
47
- // / Returns internal type encoding for primary storage. Keep these
48
- // / values consistent with the sparse runtime support library.
49
- static uint32_t getPrimaryTypeEncoding (Type tp) {
50
- if (tp.isF64 ())
51
- return 1 ;
52
- if (tp.isF32 ())
53
- return 2 ;
54
- if (tp.isInteger (64 ))
55
- return 3 ;
56
- if (tp.isInteger (32 ))
57
- return 4 ;
58
- if (tp.isInteger (16 ))
59
- return 5 ;
60
- if (tp.isInteger (8 ))
61
- return 6 ;
62
- return 0 ;
63
- }
64
-
65
- // / Returns internal type encoding for overhead storage. Keep these
66
- // / values consistent with the sparse runtime support library.
67
- static uint32_t getOverheadTypeEncoding (unsigned width) {
68
- switch (width) {
69
- default :
70
- return 1 ;
71
- case 32 :
72
- return 2 ;
73
- case 16 :
74
- return 3 ;
75
- case 8 :
76
- return 4 ;
77
- }
78
- }
79
-
80
- // / Returns internal dimension level type encoding. Keep these
81
- // / values consistent with the sparse runtime support library.
82
- static uint32_t
83
- getDimLevelTypeEncoding (SparseTensorEncodingAttr::DimLevelType dlt) {
84
- switch (dlt) {
85
- case SparseTensorEncodingAttr::DimLevelType::Dense:
86
- return 0 ;
87
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
88
- return 1 ;
89
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
90
- return 2 ;
91
- }
92
- llvm_unreachable (" Unknown SparseTensorEncodingAttr::DimLevelType" );
93
- }
94
-
95
37
// / Generates a constant zero of the given type.
96
38
inline static Value constantZero (ConversionPatternRewriter &rewriter,
97
39
Location loc, Type t) {
@@ -116,6 +58,75 @@ inline static Value constantI8(ConversionPatternRewriter &rewriter,
116
58
return rewriter.create <arith::ConstantIntOp>(loc, i, 8 );
117
59
}
118
60
61
+ // / Generates a constant of the given `Action`.
62
+ static Value constantAction (ConversionPatternRewriter &rewriter, Location loc,
63
+ Action action) {
64
+ return constantI32 (rewriter, loc, static_cast <uint32_t >(action));
65
+ }
66
+
67
+ // / Generates a constant of the internal type encoding for overhead storage.
68
+ static Value constantOverheadTypeEncoding (ConversionPatternRewriter &rewriter,
69
+ Location loc, unsigned width) {
70
+ OverheadType sec;
71
+ switch (width) {
72
+ default :
73
+ sec = OverheadType::kU64 ;
74
+ break ;
75
+ case 32 :
76
+ sec = OverheadType::kU32 ;
77
+ break ;
78
+ case 16 :
79
+ sec = OverheadType::kU16 ;
80
+ break ;
81
+ case 8 :
82
+ sec = OverheadType::kU8 ;
83
+ break ;
84
+ }
85
+ return constantI32 (rewriter, loc, static_cast <uint32_t >(sec));
86
+ }
87
+
88
+ // / Generates a constant of the internal type encoding for primary storage.
89
+ static Value constantPrimaryTypeEncoding (ConversionPatternRewriter &rewriter,
90
+ Location loc, Type tp) {
91
+ PrimaryType primary;
92
+ if (tp.isF64 ())
93
+ primary = PrimaryType::kF64 ;
94
+ else if (tp.isF32 ())
95
+ primary = PrimaryType::kF32 ;
96
+ else if (tp.isInteger (64 ))
97
+ primary = PrimaryType::kI64 ;
98
+ else if (tp.isInteger (32 ))
99
+ primary = PrimaryType::kI32 ;
100
+ else if (tp.isInteger (16 ))
101
+ primary = PrimaryType::kI16 ;
102
+ else if (tp.isInteger (8 ))
103
+ primary = PrimaryType::kI8 ;
104
+ else
105
+ llvm_unreachable (" Unknown element type" );
106
+ return constantI32 (rewriter, loc, static_cast <uint32_t >(primary));
107
+ }
108
+
109
+ // / Generates a constant of the internal dimension level type encoding.
110
+ static Value
111
+ constantDimLevelTypeEncoding (ConversionPatternRewriter &rewriter, Location loc,
112
+ SparseTensorEncodingAttr::DimLevelType dlt) {
113
+ DimLevelType dlt2;
114
+ switch (dlt) {
115
+ case SparseTensorEncodingAttr::DimLevelType::Dense:
116
+ dlt2 = DimLevelType::kDense ;
117
+ break ;
118
+ case SparseTensorEncodingAttr::DimLevelType::Compressed:
119
+ dlt2 = DimLevelType::kCompressed ;
120
+ break ;
121
+ case SparseTensorEncodingAttr::DimLevelType::Singleton:
122
+ dlt2 = DimLevelType::kSingleton ;
123
+ break ;
124
+ default :
125
+ llvm_unreachable (" Unknown SparseTensorEncodingAttr::DimLevelType" );
126
+ }
127
+ return constantI8 (rewriter, loc, static_cast <uint8_t >(dlt2));
128
+ }
129
+
119
130
// / Returns a function reference (first hit also inserts into module). Sets
120
131
// / the "_emit_c_interface" on the function declaration when requested,
121
132
// / so that LLVM lowering generates a wrapper function that takes care
@@ -238,15 +249,15 @@ static Value genBuffer(ConversionPatternRewriter &rewriter, Location loc,
238
249
// / computation.
239
250
static void newParams (ConversionPatternRewriter &rewriter,
240
251
SmallVector<Value, 8 > ¶ms, Operation *op,
241
- SparseTensorEncodingAttr &enc, uint32_t action,
252
+ SparseTensorEncodingAttr &enc, Action action,
242
253
ValueRange szs, Value ptr = Value()) {
243
254
Location loc = op->getLoc ();
244
255
ArrayRef<SparseTensorEncodingAttr::DimLevelType> dlt = enc.getDimLevelType ();
245
256
unsigned sz = dlt.size ();
246
257
// Sparsity annotations.
247
258
SmallVector<Value, 4 > attrs;
248
259
for (unsigned i = 0 ; i < sz; i++)
249
- attrs.push_back (constantI8 (rewriter, loc, getDimLevelTypeEncoding ( dlt[i]) ));
260
+ attrs.push_back (constantDimLevelTypeEncoding (rewriter, loc, dlt[i]));
250
261
params.push_back (genBuffer (rewriter, loc, attrs));
251
262
// Dimension sizes array of the enveloping tensor. Useful for either
252
263
// verification of external data, or for construction of internal data.
@@ -268,18 +279,17 @@ static void newParams(ConversionPatternRewriter &rewriter,
268
279
params.push_back (genBuffer (rewriter, loc, rev));
269
280
// Secondary and primary types encoding.
270
281
ShapedType resType = op->getResult (0 ).getType ().cast <ShapedType>();
271
- uint32_t secPtr = getOverheadTypeEncoding (enc.getPointerBitWidth ());
272
- uint32_t secInd = getOverheadTypeEncoding (enc.getIndexBitWidth ());
273
- uint32_t primary = getPrimaryTypeEncoding (resType.getElementType ());
274
- assert (primary);
275
- params.push_back (constantI32 (rewriter, loc, secPtr));
276
- params.push_back (constantI32 (rewriter, loc, secInd));
277
- params.push_back (constantI32 (rewriter, loc, primary));
282
+ params.push_back (
283
+ constantOverheadTypeEncoding (rewriter, loc, enc.getPointerBitWidth ()));
284
+ params.push_back (
285
+ constantOverheadTypeEncoding (rewriter, loc, enc.getIndexBitWidth ()));
286
+ params.push_back (
287
+ constantPrimaryTypeEncoding (rewriter, loc, resType.getElementType ()));
278
288
// User action and pointer.
279
289
Type pTp = LLVM::LLVMPointerType::get (rewriter.getI8Type ());
280
290
if (!ptr)
281
291
ptr = rewriter.create <LLVM::NullOp>(loc, pTp);
282
- params.push_back (constantI32 (rewriter, loc, action));
292
+ params.push_back (constantAction (rewriter, loc, action));
283
293
params.push_back (ptr);
284
294
}
285
295
@@ -530,7 +540,7 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
530
540
SmallVector<Value, 8 > params;
531
541
sizesFromType (rewriter, sizes, op.getLoc (), resType.cast <ShapedType>());
532
542
Value ptr = adaptor.getOperands ()[0 ];
533
- newParams (rewriter, params, op, enc, kFromFile , sizes, ptr);
543
+ newParams (rewriter, params, op, enc, Action:: kFromFile , sizes, ptr);
534
544
rewriter.replaceOp (op, genNewCall (rewriter, op, params));
535
545
return success ();
536
546
}
@@ -549,7 +559,7 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
549
559
// Generate the call to construct empty tensor. The sizes are
550
560
// explicitly defined by the arguments to the init operator.
551
561
SmallVector<Value, 8 > params;
552
- newParams (rewriter, params, op, enc, kEmpty , adaptor.getOperands ());
562
+ newParams (rewriter, params, op, enc, Action:: kEmpty , adaptor.getOperands ());
553
563
rewriter.replaceOp (op, genNewCall (rewriter, op, params));
554
564
return success ();
555
565
}
@@ -588,13 +598,13 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
588
598
auto enc = SparseTensorEncodingAttr::get (
589
599
op->getContext (), encDst.getDimLevelType (), encDst.getDimOrdering (),
590
600
encSrc.getPointerBitWidth (), encSrc.getIndexBitWidth ());
591
- newParams (rewriter, params, op, enc, kToCOO , sizes, src);
601
+ newParams (rewriter, params, op, enc, Action:: kToCOO , sizes, src);
592
602
Value coo = genNewCall (rewriter, op, params);
593
- params[3 ] = constantI32 (
594
- rewriter, loc, getOverheadTypeEncoding ( encDst.getPointerBitWidth () ));
595
- params[4 ] = constantI32 (
596
- rewriter, loc, getOverheadTypeEncoding ( encDst.getIndexBitWidth () ));
597
- params[6 ] = constantI32 (rewriter, loc, kFromCOO );
603
+ params[3 ] = constantOverheadTypeEncoding (rewriter, loc,
604
+ encDst.getPointerBitWidth ());
605
+ params[4 ] = constantOverheadTypeEncoding (rewriter, loc,
606
+ encDst.getIndexBitWidth ());
607
+ params[6 ] = constantAction (rewriter, loc, Action:: kFromCOO );
598
608
params[7 ] = coo;
599
609
rewriter.replaceOp (op, genNewCall (rewriter, op, params));
600
610
return success ();
@@ -613,7 +623,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
613
623
Type elemTp = dstTensorTp.getElementType ();
614
624
// Fabricate a no-permutation encoding for newParams().
615
625
// The pointer/index types must be those of `src`.
616
- // The dimLevelTypes aren't actually used by kToIter .
626
+ // The dimLevelTypes aren't actually used by Action::kToIterator .
617
627
encDst = SparseTensorEncodingAttr::get (
618
628
op->getContext (),
619
629
SmallVector<SparseTensorEncodingAttr::DimLevelType>(
@@ -622,7 +632,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
622
632
SmallVector<Value, 4 > sizes;
623
633
SmallVector<Value, 8 > params;
624
634
sizesFromPtr (rewriter, sizes, op, encSrc, srcTensorTp, src);
625
- newParams (rewriter, params, op, encDst, kToIter , sizes, src);
635
+ newParams (rewriter, params, op, encDst, Action:: kToIterator , sizes, src);
626
636
Value iter = genNewCall (rewriter, op, params);
627
637
Value ind = genAlloca (rewriter, loc, rank, rewriter.getIndexType ());
628
638
Value elemPtr = genAllocaScalar (rewriter, loc, elemTp);
@@ -677,7 +687,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
677
687
SmallVector<Value, 4 > sizes;
678
688
SmallVector<Value, 8 > params;
679
689
sizesFromSrc (rewriter, sizes, loc, src);
680
- newParams (rewriter, params, op, encDst, kEmptyCOO , sizes);
690
+ newParams (rewriter, params, op, encDst, Action:: kEmptyCOO , sizes);
681
691
Value ptr = genNewCall (rewriter, op, params);
682
692
Value ind = genAlloca (rewriter, loc, rank, rewriter.getIndexType ());
683
693
Value perm = params[2 ];
@@ -718,7 +728,7 @@ class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
718
728
return {};
719
729
});
720
730
// Final call to construct sparse tensor storage.
721
- params[6 ] = constantI32 (rewriter, loc, kFromCOO );
731
+ params[6 ] = constantAction (rewriter, loc, Action:: kFromCOO );
722
732
params[7 ] = ptr;
723
733
rewriter.replaceOp (op, genNewCall (rewriter, op, params));
724
734
return success ();
0 commit comments