Skip to content

Commit 10951ca

Browse files
authored
[mlir][sparse] use uint64_t type for dim/rank consistently (#69626)
1 parent d681461 commit 10951ca

File tree

1 file changed

+14
-14
lines changed
  • mlir/include/mlir/ExecutionEngine/SparseTensor

1 file changed

+14
-14
lines changed

mlir/include/mlir/ExecutionEngine/SparseTensor/Storage.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
303303
uint64_t lvlRank = getLvlRank();
304304
uint64_t valIdx = 0;
305305
// Linearize the address
306-
for (size_t lvl = 0; lvl < lvlRank; lvl++)
306+
for (uint64_t lvl = 0; lvl < lvlRank; lvl++)
307307
valIdx = valIdx * getLvlSize(lvl) + lvlCoords[lvl];
308308
values[valIdx] = val;
309309
return;
@@ -338,7 +338,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
338338
values[c] = 0;
339339
filled[c] = false;
340340
// Subsequent insertions are quick.
341-
for (uint64_t i = 1; i < count; ++i) {
341+
for (uint64_t i = 1; i < count; i++) {
342342
assert(c < added[i] && "non-lexicographic insertion");
343343
c = added[i];
344344
assert(c <= expsz);
@@ -394,27 +394,27 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
394394

395395
// In-place permutation.
396396
auto applyPerm = [this](std::vector<uint64_t> &perm) {
397-
size_t length = perm.size();
398-
size_t lvlRank = getLvlRank();
397+
uint64_t length = perm.size();
398+
uint64_t lvlRank = getLvlRank();
399399
// Cache for the current level coordinates.
400400
std::vector<P> lvlCrds(lvlRank);
401-
for (size_t i = 0; i < length; i++) {
402-
size_t current = i;
401+
for (uint64_t i = 0; i < length; i++) {
402+
uint64_t current = i;
403403
if (i != perm[current]) {
404-
for (size_t l = 0; l < lvlRank; l++)
404+
for (uint64_t l = 0; l < lvlRank; l++)
405405
lvlCrds[l] = coordinates[l][i];
406406
V val = values[i];
407407
// Deals with a permutation cycle.
408408
while (i != perm[current]) {
409-
size_t next = perm[current];
409+
uint64_t next = perm[current];
410410
// Swaps the level coordinates and value.
411-
for (size_t l = 0; l < lvlRank; l++)
411+
for (uint64_t l = 0; l < lvlRank; l++)
412412
coordinates[l][current] = coordinates[l][next];
413413
values[current] = values[next];
414414
perm[current] = current;
415415
current = next;
416416
}
417-
for (size_t l = 0; l < lvlRank; l++)
417+
for (uint64_t l = 0; l < lvlRank; l++)
418418
coordinates[l][current] = lvlCrds[l];
419419
values[current] = val;
420420
perm[current] = current;
@@ -557,7 +557,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
557557
const uint64_t lastLvl = lvlRank - 1;
558558
assert(diffLvl <= lvlRank);
559559
const uint64_t stop = lvlRank - diffLvl;
560-
for (uint64_t i = 0; i < stop; ++i) {
560+
for (uint64_t i = 0; i < stop; i++) {
561561
const uint64_t l = lastLvl - i;
562562
finalizeSegment(l, lvlCursor[l] + 1);
563563
}
@@ -569,7 +569,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
569569
V val) {
570570
const uint64_t lvlRank = getLvlRank();
571571
assert(diffLvl <= lvlRank);
572-
for (uint64_t l = diffLvl; l < lvlRank; ++l) {
572+
for (uint64_t l = diffLvl; l < lvlRank; l++) {
573573
const uint64_t c = lvlCoords[l];
574574
appendCrd(l, full, c);
575575
full = 0;
@@ -582,7 +582,7 @@ class SparseTensorStorage final : public SparseTensorStorageBase {
582582
/// in the argument differ from those in the current cursor.
583583
uint64_t lexDiff(const uint64_t *lvlCoords) const {
584584
const uint64_t lvlRank = getLvlRank();
585-
for (uint64_t l = 0; l < lvlRank; ++l) {
585+
for (uint64_t l = 0; l < lvlRank; l++) {
586586
const auto crd = lvlCoords[l];
587587
const auto cur = lvlCursor[l];
588588
if (crd > cur || (crd == cur && !isUniqueLvl(l)) ||
@@ -705,7 +705,7 @@ SparseTensorStorage<P, C, V>::SparseTensorStorage(
705705
// really use nnz and dense/sparse distribution.
706706
bool allDense = true;
707707
uint64_t sz = 1;
708-
for (uint64_t l = 0; l < lvlRank; ++l) {
708+
for (uint64_t l = 0; l < lvlRank; l++) {
709709
const DimLevelType dlt = lvlTypes[l]; // Avoid redundant bounds checking.
710710
if (isCompressedDLT(dlt)) {
711711
positions[l].reserve(sz + 1);

0 commit comments

Comments
 (0)