Commit 550eb45
Author: Peiming Liu (committed)

[mlir][sparse] use a consistent order between [dis]assembleOp and storage layout.

1 parent: 82cc2a6

File tree: 14 files changed (+164, -171 lines)

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+17, -14)

@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
 }

 def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
-    Arguments<(ins TensorOf<[AnyType]>:$values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+    Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                   TensorOf<[AnyType]>:$values)>,
     Results<(outs AnySparseTensor: $result)> {
   let summary = "Returns a sparse tensor assembled from the given values and levels";

@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
   }];

   let assemblyFormat =
-    "$values `,` $levels attr-dict"
-    "`:` type($values) `,` type($levels) `to` type($result)";
+    "` ` `(` $levels `)` `,` $values attr-dict"
+    " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";

   let hasVerifier = 1;
 }

 def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
     Arguments<(ins AnySparseTensor:$tensor,
-                   TensorOf<[AnyType]>:$out_values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
-    Results<(outs TensorOf<[AnyType]>:$ret_values,
-                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnyIndexingScalarLike:$val_len,
-                  Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                   TensorOf<[AnyType]>:$out_values)>,
+    Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                  TensorOf<[AnyType]>:$ret_values,
+                  Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                  AnyIndexingScalarLike:$val_len)> {
   let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";

   let description = [{

@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
     //                  |0.0, 0.0, 0.0, 0.0|
     %v, %p, %c, %v_len, %p_len, %c_len =
       sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
-        outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
-        -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+        out_lvls(%op, %oi) : tensor<2xindex>, tensor<3x2xindex>,
+        out_vals(%od) : tensor<3xf64> ->
+        tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
     // %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
     // %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
     // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>

@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria

   let assemblyFormat =
     "$tensor `:` type($tensor) "
-    "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
-    "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+    "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+    "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+    "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+    "`(` type($lvl_lens) `)` `,` type($val_len)";

   let hasVerifier = 1;
 }
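For quick orientation, here is a minimal before/after sketch of the assemble syntax this change introduces, assuming a CSR-like #CSR encoding and illustrative SSA names (the updated tests below are authoritative):

    // Before: values buffer first, then level buffers.
    %st = sparse_tensor.assemble %vals, %pos, %crd
            : tensor<?xf64>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf64, #CSR>

    // After: level buffers first, values last, matching the storage layout.
    %st = sparse_tensor.assemble (%pos, %crd), %vals
            : (tensor<?xindex>, tensor<?xindex>), tensor<?xf64> to tensor<8x8xf64, #CSR>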

mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+7, -20)

@@ -33,20 +33,14 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
     }
     // Convert the external representation of the values array.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
-    auto vtp = RankedTensorType::get(shape, stt.getElementType());
-    convTypes.push_back(vtp);
-    if (extraTypes)
-      extraTypes->push_back(vtp);
-
     // Convert the external representation of the position/coordinate array.
     foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
                                                Type t, FieldIndex,
                                                SparseTensorFieldKind kind,
                                                Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         ShapedType st = t.cast<ShapedType>();
         auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
         convTypes.push_back(rtp);

@@ -73,34 +67,27 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     // Convert the external representation of the values array.
     auto rtp = cast<RankedTensorType>(type);
     const SparseTensorType stt(rtp);
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
     SmallVector<Value> inputs;
     SmallVector<Type> retTypes;
     SmallVector<Type> cntTypes;
-    // Collect the external representation of the values array for
-    // input or the outgoing sparse tensor for output.
-    inputs.push_back(fromVals[idx++]);
-    if (!isIn) {
-      inputs.push_back(extraVals[extra++]);
-      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType()); // nnz
-    }
+    if (!isIn)
+      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble

     // Collect the external representations of the pos/crd arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
                                                      Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
           ShapedType st = t.cast<ShapedType>();
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
-          cntTypes.push_back(rtp.getElementType());
+          cntTypes.push_back(builder.getIndexType());
         }
       }
       return true;
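Because values are now handled by the same foreachFieldAndTypeInSparseTensor traversal as positions and coordinates, the wrappers this pass emits list the level buffers before the values buffer. A sketch of the expected rewrite, assuming a CSR-like #sparse encoding (illustrative names; the external.mlir tests below pin down the exact IR):

    // A public function with an opaque sparse argument, e.g.
    //   func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32>
    // gets an external ABI of (positions, coordinates, values):
    func.func @sparse_in(%pos: tensor<?xindex>, %crd: tensor<?xindex>,
                         %val: tensor<?xf32>) -> tensor<64x64xf32> {
      %sp = sparse_tensor.assemble (%pos, %crd), %val
              : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32>
              to tensor<64x64xf32, #sparse>
      %0 = call @_internal_sparse_in(%sp)
              : (tensor<64x64xf32, #sparse>) -> tensor<64x64xf32>
      return %0 : tensor<64x64xf32>
    }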

mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+2, -2)

@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
-  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
-                                          ValueRange{rt, ct});
+  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+                                          vt);
   return success();
 }
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+3, -7)

@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
       sz = desc.getValMemSize(rewriter, loc);
       src = desc.getValMemRef();
       dst = genToMemref(rewriter, loc, op.getOutValues());
-      // Values is the last field in descriptor, but it is the first
-      // operand in unpack operation.
-      // TODO: maybe change unpack/pack operation instead to be
-      // consistent.
-      retMem.insert(retMem.begin(), dst);
+
+      retMem.push_back(dst);
       Type valLenTp = op.getValLen().getType();
-      retLen.insert(retLen.begin(),
-                    genScalarToTensor(rewriter, loc, sz, valLenTp));
+      retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
     } else {
       assert(fKind == SparseTensorFieldKind::PosMemRef ||
              fKind == SparseTensorFieldKind::CrdMemRef);

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+13, -12)

@@ -738,12 +738,6 @@ class SparseTensorDisassembleConverter
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;
    SmallVector<Value> retLen;
-    // Get the values buffer first.
-    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
-    auto valLenTp = op.getValLen().getType();
-    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
-    retVal.push_back(vals);
-    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
     // Then get the positions and coordinates buffers.
     const Level lvlRank = stt.getLvlRank();
     Level trailCOOLen = 0;

@@ -761,15 +755,15 @@ class SparseTensorDisassembleConverter
         auto poss =
             genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(poss);
         retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       }
       if (stt.isWithCrd(l)) {
         auto crds =
             genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
-        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(crds);
         retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
       }

@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
       auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                    cooStartLvl);
       auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(poss);
       retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       // Coordinates, copied over with:
       //   for (i = 0; i < crdLen; i++)
       //     buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
-      auto buf =
-          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
       auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                       cooStartLvl);
       auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),

@@ -814,17 +807,25 @@ class SparseTensorDisassembleConverter
       args[1] = one;
       rewriter.create<memref::StoreOp>(loc, c1, buf, args);
       rewriter.setInsertionPointAfter(forOp);
-      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(buf);
       retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
     }
+    // Get the values buffer last.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
     // Converts MemRefs back to Tensors.
     assert(retVal.size() + retLen.size() == op.getNumResults());
     for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
       auto tensor = rewriter.create<bufferization::ToTensorOp>(loc, retVal[i]);
       retVal[i] =
           rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
     }
+
     // Appends the actual memory length used in each buffer returned.
     retVal.append(retLen.begin(), retLen.end());
     rewriter.replaceOp(op, retVal);
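The switch from retLen.size() - 1 to retLen.size() works because the values buffer no longer occupies slot 0: level buffers are collected first and values are appended last, so retLen.size() already indexes the next lvl_lens entry. Going by the assembly format declared in the .td change above (the inline example in the op description keeps a slightly different phrasing; the format string is authoritative), a disassemble now reads roughly as follows, with illustrative names and the #COO encoding from the op documentation:

    %p, %c, %v, %p_len, %c_len, %v_len =
      sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
        out_lvls(%op, %oi : tensor<2xindex>, tensor<3x2xindex>)
        out_vals(%od : tensor<3xf64>)
        -> (tensor<2xindex>, tensor<3x2xindex>), tensor<3xf64>, (index, index), index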

mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir (+1, -1)

@@ -85,7 +85,7 @@
 // CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
 // CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
 // CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
 // CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
 // CHECK: }
 func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,

mlir/test/Dialect/SparseTensor/external.mlir (+33, -33)

@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
 // -----

 // CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK: return %[[F]] : tensor<64x64xf32>
 // CHECK: }

@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
 // -----

 // CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK: return %[[F]] : tensor<64x64xf32>
 // CHECK: }

@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
 // -----

 // CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
 // CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK: sparse_tensor.disassemble %[[F]]
 // CHECK: return

@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
 // -----

 // CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
 // CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK: sparse_tensor.disassemble %[[F]]#1
 // CHECK: return %[[F]]#0

@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 // -----

 // CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME: %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*5]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK: sparse_tensor.disassemble %[[F]]
 // CHECK: return

@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
 // -----

 // CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME: %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME: %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME: %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME: %[[E:.*7]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
 // CHECK: %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
 // CHECK: sparse_tensor.disassemble %[[F]]
 // CHECK: return
