Skip to content

[NFC][mlir][sparse] remove redundant parameter. #75551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
if (format == CuSparseFormat::kCOO) {
// Library uses SoA COO, direct IR uses AoS COO.
if (enableRT)
return genToCoordinates(builder, loc, a, 0, /*cooStart=*/0);
return genToCoordinates(builder, loc, a, 0);
return genToCoordinatesBuffer(builder, loc, a);
}
// Formats CSR/CSC and BSR use positions at 1.
Expand All @@ -490,7 +490,7 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
if (isCOO && !enableRT)
return Value(); // nothing needed
// Formats CSR/CSC and BSR use coordinates at 1.
return genToCoordinates(builder, loc, a, 1, /*cooStart=*/isCOO ? 0 : 2);
return genToCoordinates(builder, loc, a, 1);
}

/// Generates the sparse matrix handle.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,10 +564,11 @@ Value sparse_tensor::genToPositions(OpBuilder &builder, Location loc,
}

Value sparse_tensor::genToCoordinates(OpBuilder &builder, Location loc,
Value tensor, Level lvl, Level cooStart) {
Value tensor, Level lvl) {
const auto srcTp = getSparseTensorType(tensor);
const Type crdTp = srcTp.getCrdType();
const Type memTp = get1DMemRefType(crdTp, /*withLayout=*/lvl >= cooStart);
const Type memTp =
get1DMemRefType(crdTp, /*withLayout=*/lvl >= srcTp.getCOOStart());
return builder.create<ToCoordinatesOp>(loc, memTp, tensor,
builder.getIndexAttr(lvl));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ Value genToPositions(OpBuilder &builder, Location loc, Value tensor, Level lvl);
/// stride and offset. Otherwise, the result type is a memref without
/// any specified layout.
Value genToCoordinates(OpBuilder &builder, Location loc, Value tensor,
Level lvl, Level cooStart);
Level lvl);

/// Infers the result type and generates `ToCoordinatesBufferOp`.
Value genToCoordinatesBuffer(OpBuilder &builder, Location loc, Value tensor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,6 @@ void LoopEmitter::initializeLoopEmit(
auto stt = getSparseTensorType(tensor);
const Level lvlRank = stt.getLvlRank();
const auto shape = rtp.getShape();
const Level cooStart = stt.getCOOStart();

SmallVector<Value> lvlSzs;
for (Level l = 0; l < stt.getLvlRank(); l++) {
Expand All @@ -457,12 +456,10 @@ void LoopEmitter::initializeLoopEmit(
if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
// Generate sparse primitives to obtain positions and coordinates.
positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
} else if (isSingletonLT(lvlTp) || is2OutOf4LT(lvlTp)) {
// Singleton level, fetch coordinates.
coordinatesBuffers[t][l] =
genToCoordinates(builder, loc, tensor, l, cooStart);
coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
} else {
// Dense level, nothing to fetch.
assert(isDenseLT(lvlTp));
Expand Down