Skip to content

Commit 07bf1dd

Browse files
authored
[mlir][sparse] support non-id map for [Dis]assembleOp (#80355)
1 parent 375bd22 commit 07bf1dd

File tree

3 files changed

+84
-3
lines changed

3 files changed

+84
-3
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,8 +1016,6 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
10161016
return op->emitError("the sparse-tensor must have static shape");
10171017
if (!stt.hasEncoding())
10181018
return op->emitError("the sparse-tensor must have an encoding attribute");
1019-
if (!stt.isIdentity())
1020-
return op->emitError("the sparse-tensor must have the identity mapping");
10211019

10221020
// Verifies the trailing COO.
10231021
Level cooStartLvl = stt.getCOOStart();

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,40 @@ struct TensorInsertDemapper
656656
}
657657
};
658658

659+
struct SparseAssembleDemapper : public OpRewritePattern<AssembleOp> {
660+
using OpRewritePattern::OpRewritePattern;
661+
LogicalResult matchAndRewrite(AssembleOp op,
662+
PatternRewriter &rewriter) const override {
663+
if (!hasAnyNonIdentityOperandsOrResults(op))
664+
return failure();
665+
666+
assert(hasAnySparseResult(op));
667+
auto stt = getSparseTensorType(op.getResult());
668+
rewriter.modifyOpInPlace(
669+
op, [&op, &stt]() { op.getResult().setType(stt.getDemappedType()); });
670+
rewriter.setInsertionPointAfter(op);
671+
Value out = genRemap(rewriter, stt.getEncoding(), op.getResult());
672+
rewriter.replaceAllUsesExcept(op, out, out.getDefiningOp());
673+
return success();
674+
}
675+
};
676+
677+
struct SparseDisassembleDemapper
678+
: public DemapInsRewriter<SparseDisassembleDemapper, DisassembleOp> {
679+
using DemapInsRewriter::DemapInsRewriter;
680+
LogicalResult rewriteOp(DisassembleOp op, OpAdaptor adaptor,
681+
PatternRewriter &rewriter) const {
682+
if (!hasAnyNonIdentityOperandsOrResults(op))
683+
return failure();
684+
685+
assert(hasAnySparseOperandOrResult(op));
686+
rewriter.modifyOpInPlace(op, [&op, &adaptor]() {
687+
op.getTensorMutable().assign(adaptor.getTensor());
688+
});
689+
return success();
690+
}
691+
};
692+
659693
struct ForeachOpDemapper
660694
: public DemapInsRewriter<ForeachOpDemapper, ForeachOp> {
661695
using DemapInsRewriter::DemapInsRewriter;
@@ -758,7 +792,8 @@ void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
758792
if (scope == ReinterpretMapScope::kAll ||
759793
scope == ReinterpretMapScope::kExceptGeneric) {
760794
patterns.add<TensorAllocDemapper<bufferization::AllocTensorOp>,
761-
TensorAllocDemapper<tensor::EmptyOp>, TensorInsertDemapper,
795+
TensorAllocDemapper<tensor::EmptyOp>, SparseAssembleDemapper,
796+
SparseDisassembleDemapper, TensorInsertDemapper,
762797
ForeachOpDemapper>(patterns.getContext());
763798
}
764799
}

mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,51 @@ func.func @sparse_foreach_reinterpret_map(%6 : tensor<2x4xf64, #BSR>) -> tensor<
8080
%9 = sparse_tensor.load %8 hasInserts : tensor<2x4xf64, #BSR>
8181
return %9 : tensor<2x4xf64, #BSR>
8282
}
83+
84+
85+
// -----
86+
87+
#BSR = #sparse_tensor.encoding<{
88+
map = ( i, j ) ->
89+
( i floordiv 2 : dense,
90+
j floordiv 2 : compressed,
91+
i mod 2 : dense,
92+
j mod 2 : dense
93+
)
94+
}>
95+
// CHECK-DAG: #[[$remap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense) }>
96+
// CHECK-DAG: #[[$demap:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2, d3) -> (d0 : dense, d1 : compressed, d2 : dense, d3 : dense) }>
97+
98+
// CHECK-LABEL: func.func @sparse_assemble_reinterpret_map(
99+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?xf64>,
100+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xindex>,
101+
// CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>) -> tensor<2x4xf64, #[[$remap]]> {
102+
// CHECK: %[[VAL_3:.*]] = sparse_tensor.assemble %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<1x2x2x2xf64, #[[$demap]]>
103+
// CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_3]] : tensor<1x2x2x2xf64, #[[$demap]]> to tensor<2x4xf64, #[[$remap]]>
104+
// CHECK: return %[[VAL_4]] : tensor<2x4xf64, #[[$remap]]>
105+
// CHECK: }
106+
func.func @sparse_assemble_reinterpret_map(%val : tensor<?xf64>, %pos:tensor<?xindex>, %crd:tensor<?xindex>) -> tensor<2x4xf64, #BSR> {
107+
%0 = sparse_tensor.assemble %val, %pos, %crd
108+
: tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<2x4xf64, #BSR>
109+
return %0 : tensor<2x4xf64, #BSR>
110+
}
111+
112+
// CHECK-LABEL: func.func @sparse_disassemble_reinterpret_map(
113+
// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4xf64, #[[$remap]]>,
114+
// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xf64>,
115+
// CHECK-SAME: %[[VAL_2:.*]]: tensor<?xindex>,
116+
// CHECK-SAME: %[[VAL_3:.*]]: tensor<?xindex>) -> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
117+
// CHECK: %[[VAL_4:.*]] = sparse_tensor.reinterpret_map %[[VAL_0]] : tensor<2x4xf64, #[[$remap]]> to tensor<1x2x2x2xf64, #[[$demap]]>
118+
// CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]]:2, %[[VAL_7:.*]], %[[VAL_8:.*]]:2 = sparse_tensor.disassemble %[[VAL_4]] : tensor<1x2x2x2xf64, #[[$demap]]>
119+
// CHECK: return
120+
// CHECK: }
121+
func.func @sparse_disassemble_reinterpret_map(%sp : tensor<2x4xf64, #BSR>,
122+
%od : tensor<?xf64>,
123+
%op : tensor<?xindex>,
124+
%oi : tensor<?xindex>)
125+
-> (tensor<?xf64>, tensor<?xindex>, tensor<?xindex>) {
126+
%rd, %rp, %ri, %dl, %pl, %il = sparse_tensor.disassemble %sp : tensor<2x4xf64, #BSR>
127+
outs(%od, %op, %oi : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>)
128+
-> tensor<?xf64>, (tensor<?xindex>, tensor<?xindex>), index, (index, index)
129+
return %rd, %rp, %ri : tensor<?xf64>, tensor<?xindex>, tensor<?xindex>
130+
}

0 commit comments

Comments (0)