@@ -22,13 +22,9 @@ using namespace sparse_tensor;
22
22
// Helper methods.
23
23
// ===----------------------------------------------------------------------===//
24
24
25
- // TODO: reuse StorageLayout::foreachField?
26
-
27
- // TODO: we need COO AoS and SoA
28
-
29
25
// Convert type range to new types range, with sparse tensors externalized.
30
- void convTypes (TypeRange types, SmallVectorImpl<Type> &convTypes,
31
- SmallVectorImpl<Type> *extraTypes = nullptr ) {
26
+ static void convTypes (TypeRange types, SmallVectorImpl<Type> &convTypes,
27
+ SmallVectorImpl<Type> *extraTypes = nullptr ) {
32
28
for (auto type : types) {
33
29
// All "dense" data passes through unmodified.
34
30
if (!getSparseTensorEncoding (type)) {
@@ -42,29 +38,30 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
42
38
convTypes.push_back (vtp);
43
39
if (extraTypes)
44
40
extraTypes->push_back (vtp);
45
- // Convert the external representations of the pos/crd arrays.
46
- for (Level lvl = 0 , lvlRank = stt.getLvlRank (); lvl < lvlRank; lvl++) {
47
- const auto lt = stt.getLvlType (lvl);
48
- if (isCompressedLT (lt) || isLooseCompressedLT (lt)) {
49
- auto ptp = RankedTensorType::get (shape, stt.getPosType ());
50
- auto ctp = RankedTensorType::get (shape, stt.getCrdType ());
51
- convTypes.push_back (ptp);
52
- convTypes.push_back (ctp);
53
- if (extraTypes) {
54
- extraTypes->push_back (ptp);
55
- extraTypes->push_back (ctp);
56
- }
57
- } else {
58
- assert (isDenseLT (lt)); // TODO: handle other cases
41
+
42
+ // Convert the external representation of the position/coordinate array.
43
+ foreachFieldAndTypeInSparseTensor (stt, [&convTypes, extraTypes](
44
+ Type t, FieldIndex,
45
+ SparseTensorFieldKind kind,
46
+ Level, LevelType) {
47
+ if (kind == SparseTensorFieldKind::CrdMemRef ||
48
+ kind == SparseTensorFieldKind::PosMemRef) {
49
+ ShapedType st = t.cast <ShapedType>();
50
+ auto rtp = RankedTensorType::get (st.getShape (), st.getElementType ());
51
+ convTypes.push_back (rtp);
52
+ if (extraTypes)
53
+ extraTypes->push_back (rtp);
59
54
}
60
- }
55
+ return true ;
56
+ });
61
57
}
62
58
}
63
59
64
60
// Convert input and output values to [dis]assemble ops for sparse tensors.
65
- void convVals (OpBuilder &builder, Location loc, TypeRange types,
66
- ValueRange fromVals, ValueRange extraVals,
67
- SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
61
+ static void convVals (OpBuilder &builder, Location loc, TypeRange types,
62
+ ValueRange fromVals, ValueRange extraVals,
63
+ SmallVectorImpl<Value> &toVals, unsigned extra,
64
+ bool isIn) {
68
65
unsigned idx = 0 ;
69
66
for (auto type : types) {
70
67
// All "dense" data passes through unmodified.
@@ -85,29 +82,28 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
85
82
if (!isIn) {
86
83
inputs.push_back (extraVals[extra++]);
87
84
retTypes.push_back (RankedTensorType::get (shape, stt.getElementType ()));
88
- cntTypes.push_back (builder.getIndexType ());
85
+ cntTypes.push_back (builder.getIndexType ()); // nnz
89
86
}
87
+
90
88
// Collect the external representations of the pos/crd arrays.
91
- for (Level lvl = 0 , lvlRank = stt.getLvlRank (); lvl < lvlRank; lvl++) {
92
- const auto lt = stt.getLvlType (lvl);
93
- if (isCompressedLT (lt) || isLooseCompressedLT (lt)) {
89
+ foreachFieldAndTypeInSparseTensor (stt, [&, isIn](Type t, FieldIndex,
90
+ SparseTensorFieldKind kind,
91
+ Level, LevelType) {
92
+ if (kind == SparseTensorFieldKind::CrdMemRef ||
93
+ kind == SparseTensorFieldKind::PosMemRef) {
94
94
if (isIn) {
95
95
inputs.push_back (fromVals[idx++]);
96
- inputs.push_back (fromVals[idx++]);
97
96
} else {
98
- Type pTp = stt.getPosType ();
99
- Type cTp = stt.getCrdType ();
100
- inputs.push_back (extraVals[extra++]);
97
+ ShapedType st = t.cast <ShapedType>();
98
+ auto rtp = RankedTensorType::get (st.getShape (), st.getElementType ());
101
99
inputs.push_back (extraVals[extra++]);
102
- retTypes.push_back (RankedTensorType::get (shape, pTp));
103
- retTypes.push_back (RankedTensorType::get (shape, cTp));
104
- cntTypes.push_back (pTp);
105
- cntTypes.push_back (cTp);
100
+ retTypes.push_back (rtp);
101
+ cntTypes.push_back (rtp.getElementType ());
106
102
}
107
- } else {
108
- assert (isDenseLT (lt)); // TODO: handle other cases
109
103
}
110
- }
104
+ return true ;
105
+ });
106
+
111
107
if (isIn) {
112
108
// Assemble multiple inputs into a single sparse tensor.
113
109
auto a = builder.create <sparse_tensor::AssembleOp>(loc, rtp, inputs);
0 commit comments