@@ -23,20 +23,19 @@ namespace mlir {
23
23
namespace sparse_tensor {
24
24
25
25
// / An element of a sparse tensor in coordinate-scheme representation
26
- // / (i.e., a pair of coordinates and value). For example, a rank-1
26
+ // / (i.e., a pair of coordinates and value). For example, a rank-1
27
27
// / vector element would look like
28
28
// / ({i}, a[i])
29
29
// / and a rank-5 tensor element would look like
30
30
// / ({i,j,k,l,m}, a[i,j,k,l,m])
31
31
// /
32
- // / The coordinates are represented as a (non-owning) pointer into
33
- // / a shared pool of coordinates, rather than being stored directly in
34
- // / this object. This significantly improves performance because it:
35
- // / (1) reduces the per-element memory footprint, and (2) centralizes
36
- // / the memory management for coordinates. The only downside is that
37
- // / the coordinates themselves cannot be retrieved without knowing the
38
- // / rank of the tensor to which this element belongs (and that rank is
39
- // / not stored in this object).
32
+ // / The coordinates are represented as a (non-owning) pointer into a
33
+ // / shared pool of coordinates, rather than being stored directly in this
34
+ // / object. This significantly improves performance because it reduces the
35
+ // / per-element memory footprint and centralizes the memory management for
36
+ // / coordinates. The only downside is that the coordinates themselves cannot
37
+ // / be retrieved without knowing the rank of the tensor to which this element
38
+ // / belongs (and that rank is not stored in this object).
40
39
template <typename V>
41
40
struct Element final {
42
41
Element (const uint64_t *coords, V val) : coords(coords), value(val){};
@@ -48,10 +47,6 @@ struct Element final {
48
47
template <typename V>
49
48
struct ElementLT final {
50
49
ElementLT (uint64_t rank) : rank(rank) {}
51
-
52
- // / Compares two elements a la `operator<`.
53
- // /
54
- // / Precondition: the elements must both be valid for `rank`.
55
50
bool operator ()(const Element<V> &e1 , const Element<V> &e2 ) const {
56
51
for (uint64_t d = 0 ; d < rank; ++d) {
57
52
if (e1 .coords [d] == e2 .coords [d])
@@ -60,13 +55,10 @@ struct ElementLT final {
60
55
}
61
56
return false ;
62
57
}
63
-
64
58
const uint64_t rank;
65
59
};
66
60
67
- // / The type of callback functions which receive an element. We avoid
68
- // / packaging the coordinates and value together as an `Element` object
69
- // / because this helps keep code somewhat cleaner.
61
+ // / The type of callback functions which receive an element.
70
62
template <typename V>
71
63
using ElementConsumer =
72
64
const std::function<void (const std::vector<uint64_t > &, V)> &;
@@ -89,27 +81,14 @@ class SparseTensorCOO final {
89
81
using size_type = typename vector_type::size_type;
90
82
91
83
// / Constructs a new coordinate-scheme sparse tensor with the given
92
- // / sizes and initial storage capacity.
93
- // /
94
- // / Asserts:
95
- // / * `dimSizes` has nonzero size.
96
- // / * the elements of `dimSizes` are nonzero.
84
+ // / sizes and an optional initial storage capacity.
97
85
explicit SparseTensorCOO (const std::vector<uint64_t > &dimSizes,
98
86
uint64_t capacity = 0 )
99
87
: SparseTensorCOO(dimSizes.size(), dimSizes.data(), capacity) {}
100
88
101
- // TODO: make a class for capturing known-valid sizes (a la PermutationRef),
102
- // so that `SparseTensorStorage::toCOO` can avoid redoing these assertions.
103
- // Also so that we can enforce the asserts *before* copying into `dimSizes`.
104
- //
105
89
// / Constructs a new coordinate-scheme sparse tensor with the given
106
- // / sizes and initial storage capacity.
107
- // /
108
- // / Precondition: `dimSizes` must be valid for `dimRank`.
109
- // /
110
- // / Asserts:
111
- // / * `dimRank` is nonzero.
112
- // / * the elements of `dimSizes` are nonzero.
90
+ // / sizes and an optional initial storage capacity. The size of the
91
+ // / dimSizes array is determined by dimRank.
113
92
explicit SparseTensorCOO (uint64_t dimRank, const uint64_t *dimSizes,
114
93
uint64_t capacity = 0 )
115
94
: dimSizes(dimSizes, dimSizes + dimRank), isSorted(true ) {
@@ -134,16 +113,7 @@ class SparseTensorCOO final {
134
113
// / Returns the `operator<` closure object for the COO's element type.
135
114
ElementLT<V> getElementLT () const { return ElementLT<V>(getRank ()); }
136
115
137
- // / Adds an element to the tensor. This method does not check whether
138
- // / `dimCoords` is already associated with a value, it adds it regardless.
139
- // / Resolving such conflicts is left up to clients of the iterator
140
- // / interface.
141
- // /
142
- // / This method invalidates all iterators.
143
- // /
144
- // / Asserts:
145
- // / * the `dimCoords` is valid for `getRank`.
146
- // / * the components of `dimCoords` are valid for `getDimSizes`.
116
+ // / Adds an element to the tensor. This method invalidates all iterators.
147
117
void add (const std::vector<uint64_t > &dimCoords, V val) {
148
118
const uint64_t *base = coordinates.data ();
149
119
const uint64_t size = coordinates.size ();
@@ -154,7 +124,7 @@ class SparseTensorCOO final {
154
124
" Coordinate is too large for the dimension" );
155
125
coordinates.push_back (dimCoords[d]);
156
126
}
157
- // This base only changes if `coordinates` was reallocated. In which
127
+ // This base only changes if `coordinates` was reallocated. In which
158
128
// case, we need to correct all previous pointers into the vector.
159
129
// Note that this only happens if we did not set the initial capacity
160
130
// right, and then only for every internal vector reallocation (which
@@ -175,11 +145,9 @@ class SparseTensorCOO final {
175
145
const_iterator begin () const { return elements.cbegin (); }
176
146
const_iterator end () const { return elements.cend (); }
177
147
178
- // / Sorts elements lexicographically by coordinates. If a coordinate
148
+ // / Sorts elements lexicographically by coordinates. If a coordinate
179
149
// / is mapped to multiple values, then the relative order of those
180
- // / values is unspecified.
181
- // /
182
- // / This method invalidates all iterators.
150
+ // / values is unspecified. This method invalidates all iterators.
183
151
void sort () {
184
152
if (isSorted)
185
153
return ;
0 commit comments