[mlir][sparse] use a consistent order between [dis]assembleOp and sto… #84079

Merged: 2 commits, Mar 6, 2024
31 changes: 17 additions & 14 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
 }
 
 def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
-    Arguments<(ins TensorOf<[AnyType]>:$values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+    Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                   TensorOf<[AnyType]>:$values)>,
     Results<(outs AnySparseTensor: $result)> {
   let summary = "Returns a sparse tensor assembled from the given values and levels";
 
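With levels leading, call sites swap operand order, and the printed form changes accordingly (see the assemblyFormat hunk below). A minimal before/after sketch, assuming a CSR-style encoding #CSR and illustrative buffer names:

    // Old: values buffer first.
    //   %t = sparse_tensor.assemble %vals, %pos, %crd
    //       : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf64, #CSR>
    // New: level buffers first, parenthesized.
    %t = sparse_tensor.assemble (%pos, %crd), %vals
        : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<8x8xf64, #CSR>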
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
   }];
 
   let assemblyFormat =
-    "$values `,` $levels attr-dict"
-    "`:` type($values) `,` type($levels) `to` type($result)";
+    "` ` `(` $levels `)` `,` $values attr-dict"
+    " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
 
   let hasVerifier = 1;
 }
 
 def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
     Arguments<(ins AnySparseTensor:$tensor,
-                   TensorOf<[AnyType]>:$out_values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
-    Results<(outs TensorOf<[AnyType]>:$ret_values,
-                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnyIndexingScalarLike:$val_len,
-                  Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                   TensorOf<[AnyType]>:$out_values)>,
+    Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                  TensorOf<[AnyType]>:$ret_values,
+                  Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                  AnyIndexingScalarLike:$val_len)> {
   let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
 
   let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
     //                  |0.0, 0.0, 0.0, 0.0|
     %v, %p, %c, %v_len, %p_len, %c_len =
         sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
-                     outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
-                     -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+                     out_lvls(%op, %oi) : tensor<2xindex>, tensor<3x2xindex>,
+                     out_vals(%od) : tensor<3xf64> ->
+                     tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
     // %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
     // %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
     // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +147,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
 
   let assemblyFormat =
     "$tensor `:` type($tensor) "
-    "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
-    "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+    "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+    "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+    "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+    "`(` type($lvl_lens) `)` `,` type($val_len)";
 
   let hasVerifier = 1;
 }
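Putting both ops together, a round trip in the new syntax reads as follows (a sketch following the assemblyFormat strings above; #CSR and the SSA names are assumptions):

    %t = sparse_tensor.assemble (%pos, %crd), %vals
        : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<8x8xf64, #CSR>
    %p, %c, %v, %p_len, %c_len, %v_len =
        sparse_tensor.disassemble %t : tensor<8x8xf64, #CSR>
          out_lvls(%op, %oc : tensor<?xindex>, tensor<?xindex>)
          out_vals(%ov : tensor<?xf64>)
          -> (tensor<?xindex>, tensor<?xindex>), tensor<?xf64>, (index, index), index

Levels, values, level lengths, and the value length now follow the same order everywhere, matching the storage layout.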
33 changes: 10 additions & 23 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -31,22 +31,16 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
       convTypes.push_back(type);
       continue;
     }
-    // Convert the external representation of the values array.
+
+    // Convert the external representation of the position/coordinate array
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
-    auto vtp = RankedTensorType::get(shape, stt.getElementType());
-    convTypes.push_back(vtp);
-    if (extraTypes)
-      extraTypes->push_back(vtp);
-
-    // Convert the external representation of the position/coordinate array.
     foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
                                                Type t, FieldIndex,
                                                SparseTensorFieldKind kind,
                                                Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         ShapedType st = t.cast<ShapedType>();
         auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
         convTypes.push_back(rtp);
@@ -70,37 +64,30 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
       toVals.push_back(fromVals[idx++]);
       continue;
     }
-    // Convert the external representation of the values array.
+    // Handle sparse data.
     auto rtp = cast<RankedTensorType>(type);
     const SparseTensorType stt(rtp);
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
     SmallVector<Value> inputs;
     SmallVector<Type> retTypes;
     SmallVector<Type> cntTypes;
-    // Collect the external representation of the values array for
-    // input or the outgoing sparse tensor for output.
-    inputs.push_back(fromVals[idx++]);
-    if (!isIn) {
-      inputs.push_back(extraVals[extra++]);
-      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType()); // nnz
-    }
+    if (!isIn)
+      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
     // Collect the external representations of the pos/crd arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
                                                      Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
           ShapedType st = t.cast<ShapedType>();
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
-          cntTypes.push_back(rtp.getElementType());
+          cntTypes.push_back(builder.getIndexType());
         }
       }
       return true;
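The effect of convTypes/convVals on a public wrapper is what the external.mlir checks below verify; roughly, for a two-level sparse argument (a sketch modeled on those tests, with #sparse and the _internal_ naming taken from them):

    func.func @sparse_in(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
                         %vals: tensor<?xf32>) -> tensor<64x64xf32> {
      %t = sparse_tensor.assemble (%pos, %crd), %vals
          : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<64x64xf32, #sparse>
      %r = call @_internal_sparse_in(%t)
          : (tensor<64x64xf32, #sparse>) -> tensor<64x64xf32>
      return %r : tensor<64x64xf32>
    }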
mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
-  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
-                                          ValueRange{rt, ct});
+  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+                                          vt);
   return success();
 }

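With the builder arguments swapped, the SpGEMM rewrite emits the op in the new operand order, e.g. (illustrative names; compare the updated gpu_spgemm_lib.mlir check below):

    %c = sparse_tensor.assemble (%rt, %ct), %vt
        : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #CSR>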
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
       sz = desc.getValMemSize(rewriter, loc);
       src = desc.getValMemRef();
       dst = genToMemref(rewriter, loc, op.getOutValues());
-      // Values is the last field in descriptor, but it is the first
-      // operand in unpack operation.
-      // TODO: maybe change unpack/pack operation instead to be
-      // consistent.
-      retMem.insert(retMem.begin(), dst);
+
+      retMem.push_back(dst);
       Type valLenTp = op.getValLen().getType();
-      retLen.insert(retLen.begin(),
-                    genScalarToTensor(rewriter, loc, sz, valLenTp));
+      retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
     } else {
       assert(fKind == SparseTensorFieldKind::PosMemRef ||
              fKind == SparseTensorFieldKind::CrdMemRef);
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,13 +738,7 @@ class SparseTensorDisassembleConverter
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;
     SmallVector<Value> retLen;
-    // Get the values buffer first.
-    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
-    auto valLenTp = op.getValLen().getType();
-    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
-    retVal.push_back(vals);
-    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
-    // Then get the positions and coordinates buffers.
+    // Get the positions and coordinates buffers.
     const Level lvlRank = stt.getLvlRank();
     Level trailCOOLen = 0;
     for (Level l = 0; l < lvlRank; l++) {
@@ -761,15 +755,15 @@ class SparseTensorDisassembleConverter
         auto poss =
             genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(poss);
         retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       }
       if (stt.isWithCrd(l)) {
         auto crds =
             genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
-        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(crds);
         retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
       }
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
       auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                    cooStartLvl);
       auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(poss);
       retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       // Coordinates, copied over with:
       //    for (i = 0; i < crdLen; i++)
       //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
-      auto buf =
-          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
       auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                       cooStartLvl);
       auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,17 +807,25 @@ class SparseTensorDisassembleConverter
       args[1] = one;
       rewriter.create<memref::StoreOp>(loc, c1, buf, args);
       rewriter.setInsertionPointAfter(forOp);
-      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(buf);
       retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
     }
+    // Get the values buffer last.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
     // Converts MemRefs back to Tensors.
+    assert(retVal.size() + retLen.size() == op.getNumResults());
     for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
       auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
       retVal[i] =
           rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
     }
     // Appends the actual memory length used in each buffer returned.
     retVal.append(retLen.begin(), retLen.end());
     rewriter.replaceOp(op, retVal);
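The dropped "- 1" offsets follow from the new result order: values no longer occupy the first slot, so while level buffers are emitted, retLen.size() equals the number of levels processed so far, which is exactly the current index into lvl_lens and out_levels; val_len is handled separately after the loop. For the COO example from the op description (a sketch following the assemblyFormat; #COO assumed):

    %p, %c, %v, %p_len, %c_len, %v_len =
        sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
          out_lvls(%op, %oi : tensor<2xindex>, tensor<3x2xindex>)
          out_vals(%od : tensor<3xf64>)
          -> (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>, (index, index), index
    // lvl_lens[0] pairs with the positions buffer (retLen.size() == 0 when emitted),
    // lvl_lens[1] with the coordinates buffer (retLen.size() == 1), and val_len comes last.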
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
 // CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
 // CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
 // CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
 // CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
 // CHECK: }
 func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
66 changes: 33 additions & 33 deletions mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME:  %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:  %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:  %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:       %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:  %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:  %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:  %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:       %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:       %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK:       return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME:  %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME:  %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:  %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:  %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:       %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:  %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME:  %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:  %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:  %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:       %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:       %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK:       return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
 // -----
 
 // CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME:  %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:  %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:  %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:  %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:  %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:  %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:  %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:  %[[A:.*3]]: tensor<?xf32>)
 // CHECK:       %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK:       sparse_tensor.disassemble %[[F]]
 // CHECK:       return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME:  %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:  %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:  %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:  %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:  %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:  %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:  %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:  %[[A:.*3]]: tensor<?xf32>)
 // CHECK:       %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK:       sparse_tensor.disassemble %[[F]]#1
 // CHECK:       return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME:  %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:  %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:  %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:  %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME:  %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME:  %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:       %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:  %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:  %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:  %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME:  %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME:  %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:  %[[D:.*5]]: tensor<?xf32>)
+// CHECK:       %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:       %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK:       sparse_tensor.disassemble %[[F]]
 // CHECK:       return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME:  %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:  %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:  %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:  %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME:  %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME:  %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME:  %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME:  %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:       %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME:  %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:  %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:  %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME:  %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME:  %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:  %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME:  %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME:  %[[E:.*7]]: tensor<?xf32>)
+// CHECK:       %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
 // CHECK:       %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
 // CHECK:       sparse_tensor.disassemble %[[F]]
 // CHECK:       return