Skip to content

Commit 5b15e44

Browse files
Peiming Liu (PeimingLiu)
authored and committed
address comment + fix bug
1 parent ddc0130 commit 5b15e44

File tree

4 files changed

+31
-30
lines changed

4 files changed

+31
-30
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -313,16 +313,16 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
313313
// Compute the subsection size.
314314
Value size = c0;
315315
for (auto [loop, stride] : remDepStack[t][lvl]) {
316-
Value loopHi = loopHighs[loop];
317-
size = ADDI(size, MULI(loopHi, C_IDX(stride)));
316+
Value idxMax = SUBI(loopHighs[loop], C_IDX(1));
317+
size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
318318
}
319319
it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
320320
std::move(lvlIt), size, curDep.second);
321321
} else {
322-
Value size = loopHighs[loop];
323322
const SparseIterator &subSectIter = *iters[t][lvl].back();
324323
it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
325-
std::move(lvlIt), size, curDep.second);
324+
std::move(lvlIt), loopHighs[loop],
325+
curDep.second);
326326
}
327327
lastIter[t] = it.get();
328328
iters[t][lvl].emplace_back(std::move(it));

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -665,11 +665,10 @@ class SubSectIterator : public SparseIterator {
665665
public:
666666
SubSectIterator(const NonEmptySubSectIterator &subSect,
667667
const SparseIterator &parent,
668-
std::unique_ptr<SparseIterator> &&wrap, Value size)
668+
std::unique_ptr<SparseIterator> &&wrap)
669669
: SparseIterator(IterKind::kSubSect, *wrap,
670670
/*extraCursorCnt=*/wrap->randomAccessible() ? 0 : 1),
671-
subSect(subSect), wrap(std::move(wrap)), parent(parent), size(size),
672-
helper(*this) {
671+
subSect(subSect), wrap(std::move(wrap)), parent(parent), helper(*this) {
673672
assert(subSect.tid == tid && subSect.lvl == lvl);
674673
assert(parent.kind != IterKind::kSubSect || parent.lvl + 1 == lvl);
675674
};
@@ -691,7 +690,9 @@ class SubSectIterator : public SparseIterator {
691690

692691
bool randomAccessible() const override { return wrap->randomAccessible(); };
693692
bool iteratableByFor() const override { return randomAccessible(); };
694-
Value upperBound(OpBuilder &b, Location l) const override { return size; }
693+
Value upperBound(OpBuilder &b, Location l) const override {
694+
return subSect.subSectSz;
695+
}
695696
std::pair<Value, Value> getCurPosition() const override {
696697
return wrap->getCurPosition();
697698
};
@@ -709,7 +710,7 @@ class SubSectIterator : public SparseIterator {
709710
assert(p->lvl + 1 == lvl);
710711
wrap->genInit(b, l, p);
711712
// Linearize the dense subsection index.
712-
nxLvlTupleStart = MULI(size, p->getNxLvlTupleId(b, l));
713+
nxLvlTupleStart = MULI(subSect.subSectSz, p->getNxLvlTupleId(b, l));
713714
} else {
714715
assert(subSect.lvl == lvl && subSect.isSubSectRoot());
715716
wrap->deserialize(subSect.delegate->serialize());
@@ -763,7 +764,6 @@ class SubSectIterator : public SparseIterator {
763764
std::unique_ptr<SparseIterator> wrap;
764765
const SparseIterator &parent;
765766

766-
Value size;
767767
SubSectIterHelper helper;
768768
};
769769

@@ -1354,17 +1354,18 @@ std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
13541354
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
13551355
OpBuilder &b, Location l, const SparseIterator &subSectIter,
13561356
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1357-
Value size, unsigned stride) {
1357+
Value loopBound, unsigned stride) {
13581358

13591359
// This must be a subsection iterator or a filtered subsection iterator.
13601360
auto &subSect =
13611361
llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
13621362

13631363
auto it = std::make_unique<SubSectIterator>(
1364-
subSect, *tryUnwrapFilter(&parent), std::move(wrap), size);
1364+
subSect, *tryUnwrapFilter(&parent), std::move(wrap));
1365+
13651366
if (stride != 1) {
13661367
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1367-
C_IDX(stride), /*size=*/size);
1368+
C_IDX(stride), /*size=*/loopBound);
13681369
}
13691370
return it;
13701371
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -301,8 +301,8 @@ std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
301301
/// non-empty subsection created by NonEmptySubSectIterator.
302302
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
303303
OpBuilder &b, Location l, const SparseIterator &subsectIter,
304-
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&delegate,
305-
Value size, unsigned stride);
304+
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
305+
Value loopBound, unsigned stride);
306306

307307
} // namespace sparse_tensor
308308
} // namespace mlir

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_dilated_conv_2d_nhwc_hwcf.mlir

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -90,8 +90,8 @@ func.func @entry() {
9090
%zero = arith.constant 0.00000e+00 : f32
9191

9292
%filter2D_nhwc = call @alloc_4d_filled_f32(%c3, %c3, %c3, %c1, %val) :(index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
93-
%in2D_tmp = call @alloc_4d_filled_f32(%c3, %c7, %c7, %c3, %zero) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
94-
%in2D_nhwc = tensor.insert %f10 into %in2D_tmp[%c0, %c1, %c1, %c0] : tensor<?x?x?x?xf32>
93+
%in2D_tmp = call @alloc_4d_filled_f32(%c3, %c7, %c7, %c3, %f10) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
94+
%in2D_nhwc = tensor.insert %zero into %in2D_tmp[%c0, %c1, %c1, %c0] : tensor<?x?x?x?xf32>
9595
%out2D_nhwc = call @alloc_4d_filled_f32(%c3, %c3, %c3, %c1, %zero) : (index, index, index, index, f32) -> (tensor<?x?x?x?xf32>)
9696

9797
%in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
@@ -108,35 +108,35 @@ func.func @entry() {
108108
%dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC(%in2D_nhwc_CDCC, %filter2D_nhwc_CDCC, %out2D_nhwc)
109109
: (tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32, #CDCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
110110

111-
// CHECK: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
112-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
113-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
111+
// CHECK-NEXT: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
112+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
113+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
114114
%dense_v = vector.transfer_read %dense_ret[%c0, %c0, %c0, %c0], %zero
115115
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
116116
vector.print %dense_v : vector<3x3x3x1xf32>
117117

118-
// CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
119-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
120-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
118+
// CHECK-NEXT: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
119+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
120+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
121121
%v_dual = vector.transfer_read %dual_CDCC_ret[%c0, %c0, %c0, %c0], %zero
122122
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
123123
vector.print %v_dual : vector<3x3x3x1xf32>
124124

125-
// CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
126-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
127-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
125+
// CHECK-NEXT: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
126+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
127+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
128128
%v1 = vector.transfer_read %CCCC_ret[%c0, %c0, %c0, %c0], %zero
129129
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
130130
vector.print %v1 : vector<3x3x3x1xf32>
131131

132-
// CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
133-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
134-
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
132+
// CHECK-NEXT: ( ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 520 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
133+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ),
134+
// CHECK-SAME: ( ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ), ( ( 540 ), ( 540 ), ( 540 ) ) ) )
135135
%v2 = vector.transfer_read %CDCC_ret[%c0, %c0, %c0, %c0], %zero
136136
: tensor<?x?x?x?xf32>, vector<3x3x3x1xf32>
137137
vector.print %v1 : vector<3x3x3x1xf32>
138138

139-
// Free the resources
139+
// Free the resources.
140140
bufferization.dealloc_tensor %in2D_nhwc : tensor<?x?x?x?xf32>
141141
bufferization.dealloc_tensor %filter2D_nhwc : tensor<?x?x?x?xf32>
142142
bufferization.dealloc_tensor %out2D_nhwc : tensor<?x?x?x?xf32>

0 commit comments

Comments
 (0)