Skip to content

Commit 2cc4b3d

Browse files
authored
[mlir][sparse] code cleanup using the assumption that dim2lvl maps ar… (#72894)
…e simplified.
1 parent 445f6f1 commit 2cc4b3d

File tree

2 files changed

+16
-20
lines changed

2 files changed

+16
-20
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -745,8 +745,8 @@ class SparseTensorAllocConverter
745745
const auto resType = getSparseTensorType(op);
746746
if (!resType.hasEncoding())
747747
return failure();
748-
Location loc = op.getLoc();
749748

749+
Location loc = op.getLoc();
750750
// Deal with copy.
751751
if (op.getCopy()) {
752752
auto desc = getDescriptorFromTensorTuple(adaptor.getCopy());
@@ -768,16 +768,14 @@ class SparseTensorAllocConverter
768768
return success();
769769
}
770770

771-
// Construct the dim/lvl sizes and the (unused) dim2lvl/lvl2dim buffers.
772-
SmallVector<Value> dimSizesValues;
771+
if (!resType.isIdentity()) {
772+
return rewriter.notifyMatchFailure(
773+
op, "try run --sparse-reinterpret-map before codegen");
774+
}
775+
// Level size equals to dimension size since lvl2dim map is an identity map.
773776
SmallVector<Value> lvlSizesValues;
774-
Value dimSizesBuffer;
775-
Value dim2lvlBuffer;
776-
Value lvl2dimBuffer;
777777
createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
778-
dimSizesValues);
779-
genMapBuffers(rewriter, loc, resType, dimSizesValues, dimSizesBuffer,
780-
lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
778+
/*dimSizesValues=*/lvlSizesValues);
781779

782780
// Construct allocation for each field.
783781
Value sizeHint = op.getSizeHint();
@@ -809,19 +807,17 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
809807
const auto resType = getSparseTensorType(op);
810808
if (!resType.hasEncoding())
811809
return failure();
812-
Location loc = op.getLoc();
813810

814-
// Construct the dim/lvl sizes and the (unused) dim2lvl/lvl2dim buffers.
815-
SmallVector<Value> dimSizesValues;
811+
if (!resType.isIdentity()) {
812+
return rewriter.notifyMatchFailure(
813+
op, "try run --sparse-reinterpret-map before codegen");
814+
}
815+
816+
Location loc = op.getLoc();
817+
// Level size equals to dimension size since lvl2dim map is an identity map.
816818
SmallVector<Value> lvlSizesValues;
817-
Value dimSizesBuffer;
818-
Value dim2lvlBuffer;
819-
Value lvl2dimBuffer;
820819
createDimSizes(rewriter, loc, resType, adaptor.getDynamicSizes(),
821-
dimSizesValues);
822-
genMapBuffers(rewriter, loc, resType, dimSizesValues, dimSizesBuffer,
823-
lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
824-
820+
/*dimSizesValues=*/lvlSizesValues);
825821
// Construct allocation for each field.
826822
Value sizeHint; // none
827823
SmallVector<Value> fields;

mlir/test/Dialect/SparseTensor/codegen.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
1+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-reinterpret-map --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
22

33
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
44

0 commit comments

Comments
 (0)