Skip to content

[mlir][sparse] avoid non-perm on sparse tensor convert for new #72459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1189,27 +1189,38 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
LogicalResult matchAndRewrite(NewOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const auto dstTp = getSparseTensorType(op.getResult());
const auto encDst = dstTp.getEncoding();
if (!dstTp.hasEncoding() || getCOOStart(encDst) == 0)
auto stt = getSparseTensorType(op.getResult());
auto enc = stt.getEncoding();
if (!stt.hasEncoding() || getCOOStart(enc) == 0)
return failure();

// Implement the NewOp as follows:
// %orderedCoo = sparse_tensor.new %filename
// %t = sparse_tensor.convert %orderedCoo
// with enveloping reinterpreted_map ops for non-permutations.
RankedTensorType dstTp = stt.getRankedTensorType();
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
Value convert = rewriter.replaceOpWithNewOp<ConvertOp>(
op, dstTp.getRankedTensorType(), cooTensor);
Value convert = cooTensor;
if (!stt.isPermutation()) { // demap coo, demap dstTp
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);
dstTp = getSparseTensorType(convert).withEncoding(enc.withoutDimToLvl());
}
convert = rewriter.create<ConvertOp>(loc, dstTp, convert);
if (!stt.isPermutation()) // remap to original enc
convert = rewriter.create<ReinterpretMapOp>(loc, enc, convert);
rewriter.replaceOp(op, convert);

// Release the ordered COO tensor.
// Release the temporary ordered COO tensor.
rewriter.setInsertionPointAfterValue(convert);
rewriter.create<DeallocTensorOp>(loc, cooTensor);

return success();
}
};

/// Sparse rewriting rule for the out operator.
struct OutRewriter : public OpRewritePattern<OutOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(OutOp op,
Expand Down Expand Up @@ -1250,6 +1261,7 @@ struct OutRewriter : public OpRewritePattern<OutOp> {
primaryTypeFunctionSuffix(eltTp)};
Value value = genAllocaScalar(rewriter, loc, eltTp);
ModuleOp module = op->getParentOfType<ModuleOp>();

// For each element in the source tensor, output the element.
rewriter.create<ForeachOp>(
loc, src, std::nullopt,
Expand Down