Skip to content

Commit f40ee6e

Browse files
authored
[mlir][sparse] assemble SoA COO correctly. (llvm#82449)
1 parent 5a45d32 commit f40ee6e

File tree

2 files changed

+59
-39
lines changed

2 files changed

+59
-39
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp

Lines changed: 35 additions & 39 deletions
Original file line number · Diff line number · Diff line change
@@ -22,13 +22,9 @@ using namespace sparse_tensor;
2222
// Helper methods.
2323
//===----------------------------------------------------------------------===//
2424

25-
// TODO: reuse StorageLayout::foreachField?
26-
27-
// TODO: we need COO AoS and SoA
28-
2925
// Convert type range to new types range, with sparse tensors externalized.
30-
void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
31-
SmallVectorImpl<Type> *extraTypes = nullptr) {
26+
static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
27+
SmallVectorImpl<Type> *extraTypes = nullptr) {
3228
for (auto type : types) {
3329
// All "dense" data passes through unmodified.
3430
if (!getSparseTensorEncoding(type)) {
@@ -42,29 +38,30 @@ void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
4238
convTypes.push_back(vtp);
4339
if (extraTypes)
4440
extraTypes->push_back(vtp);
45-
// Convert the external representations of the pos/crd arrays.
46-
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
47-
const auto lt = stt.getLvlType(lvl);
48-
if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
49-
auto ptp = RankedTensorType::get(shape, stt.getPosType());
50-
auto ctp = RankedTensorType::get(shape, stt.getCrdType());
51-
convTypes.push_back(ptp);
52-
convTypes.push_back(ctp);
53-
if (extraTypes) {
54-
extraTypes->push_back(ptp);
55-
extraTypes->push_back(ctp);
56-
}
57-
} else {
58-
assert(isDenseLT(lt)); // TODO: handle other cases
41+
42+
// Convert the external representation of the position/coordinate array.
43+
foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
44+
Type t, FieldIndex,
45+
SparseTensorFieldKind kind,
46+
Level, LevelType) {
47+
if (kind == SparseTensorFieldKind::CrdMemRef ||
48+
kind == SparseTensorFieldKind::PosMemRef) {
49+
ShapedType st = t.cast<ShapedType>();
50+
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
51+
convTypes.push_back(rtp);
52+
if (extraTypes)
53+
extraTypes->push_back(rtp);
5954
}
60-
}
55+
return true;
56+
});
6157
}
6258
}
6359

6460
// Convert input and output values to [dis]assemble ops for sparse tensors.
65-
void convVals(OpBuilder &builder, Location loc, TypeRange types,
66-
ValueRange fromVals, ValueRange extraVals,
67-
SmallVectorImpl<Value> &toVals, unsigned extra, bool isIn) {
61+
static void convVals(OpBuilder &builder, Location loc, TypeRange types,
62+
ValueRange fromVals, ValueRange extraVals,
63+
SmallVectorImpl<Value> &toVals, unsigned extra,
64+
bool isIn) {
6865
unsigned idx = 0;
6966
for (auto type : types) {
7067
// All "dense" data passes through unmodified.
@@ -85,29 +82,28 @@ void convVals(OpBuilder &builder, Location loc, TypeRange types,
8582
if (!isIn) {
8683
inputs.push_back(extraVals[extra++]);
8784
retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
88-
cntTypes.push_back(builder.getIndexType());
85+
cntTypes.push_back(builder.getIndexType()); // nnz
8986
}
87+
9088
// Collect the external representations of the pos/crd arrays.
91-
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
92-
const auto lt = stt.getLvlType(lvl);
93-
if (isCompressedLT(lt) || isLooseCompressedLT(lt)) {
89+
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
90+
SparseTensorFieldKind kind,
91+
Level, LevelType) {
92+
if (kind == SparseTensorFieldKind::CrdMemRef ||
93+
kind == SparseTensorFieldKind::PosMemRef) {
9494
if (isIn) {
9595
inputs.push_back(fromVals[idx++]);
96-
inputs.push_back(fromVals[idx++]);
9796
} else {
98-
Type pTp = stt.getPosType();
99-
Type cTp = stt.getCrdType();
100-
inputs.push_back(extraVals[extra++]);
97+
ShapedType st = t.cast<ShapedType>();
98+
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
10199
inputs.push_back(extraVals[extra++]);
102-
retTypes.push_back(RankedTensorType::get(shape, pTp));
103-
retTypes.push_back(RankedTensorType::get(shape, cTp));
104-
cntTypes.push_back(pTp);
105-
cntTypes.push_back(cTp);
100+
retTypes.push_back(rtp);
101+
cntTypes.push_back(rtp.getElementType());
106102
}
107-
} else {
108-
assert(isDenseLT(lt)); // TODO: handle other cases
109103
}
110-
}
104+
return true;
105+
});
106+
111107
if (isIn) {
112108
// Assemble multiple inputs into a single sparse tensor.
113109
auto a = builder.create<sparse_tensor::AssembleOp>(loc, rtp, inputs);

mlir/test/Dialect/SparseTensor/external.mlir

Lines changed: 24 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -100,3 +100,27 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
100100
func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
101101
return %arg0 : tensor<64x64xf32, #sparse>
102102
}
103+
104+
// -----
105+
106+
// CHECK-LABEL: func.func @sparse_inout_coo_soa(
107+
// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
108+
// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
109+
// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
110+
// CHECK-SAME: %[[D:.*3]]: tensor<?xindex>,
111+
// CHECK-SAME: %[[E:.*4]]: tensor<?xf32>,
112+
// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>,
113+
// CHECK-SAME: %[[G:.*6]]: tensor<?xindex>,
114+
// CHECK-SAME: %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
115+
// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
116+
// CHECK: %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
117+
// CHECK: sparse_tensor.disassemble %[[F]]
118+
// CHECK: return
119+
// CHECK: }
120+
// CHECK: func.func private @_internal_sparse_inout
121+
#sparse = #sparse_tensor.encoding<{
122+
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa))
123+
}>
124+
func.func @sparse_inout_coo_soa(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32, #sparse> {
125+
return %arg0 : tensor<64x64xf32, #sparse>
126+
}

0 commit comments

Comments (0)