Skip to content

Commit ed59a44

Browse files
author
Peiming Liu
committed
address comments.
1 parent 149b0dd commit ed59a44

File tree

6 files changed

+32
-26
lines changed

6 files changed

+32
-26
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1769,18 +1769,18 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
17691769
});
17701770
}
17711771

1772-
// The block arguments starts with referenced coordinates, follows by
1773-
// user-provided iteration arguments and ends with iterators.
1772+
// The block arguments starts with user-provided iteration arguments,
1773+
// follows by referenced coordinates and ends with iterators.
17741774
Block::BlockArgListType getCrds(unsigned regionIdx) {
17751775
return getRegion(regionIdx).getArguments()
1776-
.take_front(getCrdUsedLvls().count());
1776+
.slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
17771777
}
1778-
unsigned getNumRegionIterArgs(unsigned regionIdx) {
1778+
unsigned getNumRegionIterArgs() {
17791779
return getInitArgs().size();
17801780
}
17811781
Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
17821782
return getRegion(regionIdx).getArguments()
1783-
.slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
1783+
.take_front(getNumRegionIterArgs());
17841784
}
17851785
Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
17861786
return getRegion(regionIdx).getArguments()

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,16 +2293,18 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
22932293
if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
22942294
return failure();
22952295

2296-
if (failed(parseUsedCoordList(parser, state, blockArgs)))
2296+
SmallVector<OpAsmParser::Argument> coords;
2297+
if (failed(parseUsedCoordList(parser, state, coords)))
22972298
return failure();
2298-
size_t numCrds = blockArgs.size();
2299+
size_t numCrds = coords.size();
22992300

23002301
// Parse "iter_args(%arg = %init, ...)"
23012302
SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
23022303
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
23032304
if (hasIterArgs)
23042305
if (parser.parseAssignmentList(blockArgs, initArgs))
23052306
return failure();
2307+
blockArgs.append(coords);
23062308

23072309
SmallVector<Type> iterSpaceTps;
23082310
// parse ": (sparse_tensor.iter_space, ...) -> ret"
@@ -2326,8 +2328,8 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
23262328
state.operands.append(spacesVals);
23272329

23282330
if (hasIterArgs) {
2329-
// Strip off leading args that used for coordinates.
2330-
MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
2331+
// Strip off trailing args that used for coordinates.
2332+
MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
23312333
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
23322334
return parser.emitError(
23332335
parser.getNameLoc(),
@@ -2609,8 +2611,10 @@ void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
26092611
cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
26102612
// All ones.
26112613
I64BitSet set((1 << rank) - 1);
2612-
// Fake cases bits. We need to preallocate all the regions as Region can not
2613-
// be dynamically added later after the operation is created.
2614+
// Generates all-zero case bits (they only serve as placeholders), which are
2615+
// supposed to be overriden later. We need to preallocate all the regions as
2616+
// mlir::Region cannot be dynamically added later after the operation is
2617+
// created.
26142618
SmallVector<int64_t> caseBits(numCases, 0);
26152619
ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
26162620
return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
@@ -2701,7 +2705,7 @@ ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
27012705

27022706
LogicalResult CoIterateOp::verifyRegions() {
27032707
for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2704-
if (getNumRegionIterArgs(r) != getNumResults())
2708+
if (getNumRegionIterArgs() != getNumResults())
27052709
return emitOpError(
27062710
"mismatch in number of basic block args and defined values");
27072711

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
904904
// basic block where scf::Yield should be inserted.
905905
}
906906

907-
/// Generate a case region in the coiterate operation.
907+
/// Generates a case region in the coiterate operation.
908908
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
909909
unsigned caseIdx, LatPointId allCase,
910910
LatPointId curCase,
@@ -920,8 +920,8 @@ static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
920920
if (curCaseBits.test(set))
921921
caseBit.set(idx);
922922

923-
env.emitter().enterCurCoIterationCase(builder, env.op().getLoc(), caseBit,
924-
caseIdx, reduc);
923+
env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
924+
caseIdx, reduc);
925925
}
926926

927927
/// Generates a single if-statement within a while-loop.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -615,10 +615,11 @@ bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
615615
return true;
616616
}
617617

618-
Region *LoopEmitter::enterCurCoIterationCase(OpBuilder &builder, Location loc,
619-
I64BitSet caseBit,
620-
unsigned caseIdx,
621-
MutableArrayRef<Value> reduc) {
618+
Region *LoopEmitter::enterCurrentCoIterationCase(OpBuilder &builder,
619+
Location loc,
620+
I64BitSet caseBit,
621+
unsigned caseIdx,
622+
MutableArrayRef<Value> reduc) {
622623
auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
623624
SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
624625
cases[caseIdx] = builder.getI64IntegerAttr(caseBit);
@@ -628,11 +629,12 @@ Region *LoopEmitter::enterCurCoIterationCase(OpBuilder &builder, Location loc,
628629
assert(caseRegion.getBlocks().empty() &&
629630
"re-initialize the same coiteration case region.");
630631

631-
// Each block starts with a list of used coordinates of index type.
632+
// Each block starts with by a list of user-provided iteration arguments.
633+
TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
634+
// Followed by a list of used coordinates of index type.
632635
SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
633636
builder.getIndexType());
634-
// Follows by a list of user-provided iteration arguments.
635-
TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
637+
636638
blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
637639
// Ends with a set of iterators that defines the actually iteration space.
638640
for (auto i : caseBit.bits()) {

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ class LoopEmitter {
148148
unsigned numCases, MutableArrayRef<Value> reduc = {},
149149
bool isParallel = false, bool needsUniv = false);
150150

151-
Region *enterCurCoIterationCase(OpBuilder &builder, Location loc,
152-
I64BitSet caseBit, unsigned caseIdx,
153-
MutableArrayRef<Value> reduc);
151+
Region *enterCurrentCoIterationCase(OpBuilder &builder, Location loc,
152+
I64BitSet caseBit, unsigned caseIdx,
153+
MutableArrayRef<Value> reduc);
154154

155155
/// Generates code to exit the current loop (e.g., generates yields, forwards
156156
/// loop induction variables, etc).

mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse | FileCheck %s --check-prefix="ITER"
22

3-
// TODO: temporially disabled since there is no lowering rules from `coiterate` to `scf`.
3+
// TODO: temporarilly disabled since there is no lowering rules from `coiterate` to `scf`.
44
// R_U_N: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
55

66

0 commit comments

Comments
 (0)