-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][sparse] use a consistent order between [dis]assembleOp and storage layout #84079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sparse Author: Peiming Liu (PeimingLiu) Changes: use a consistent order between [dis]assembleOp and storage layout. Patch is 44.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84079.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3a5447d29f866d..feed15d6af0544 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
}
def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
- Arguments<(ins TensorOf<[AnyType]>:$values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+ Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+ TensorOf<[AnyType]>:$values)>,
Results<(outs AnySparseTensor: $result)> {
let summary = "Returns a sparse tensor assembled from the given values and levels";
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
}];
let assemblyFormat =
- "$values `,` $levels attr-dict"
- "`:` type($values) `,` type($levels) `to` type($result)";
+ "` ` `(` $levels `)` `,` $values attr-dict"
+ " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
Arguments<(ins AnySparseTensor:$tensor,
- TensorOf<[AnyType]>:$out_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
- Results<(outs TensorOf<[AnyType]>:$ret_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
- AnyIndexingScalarLike:$val_len,
- Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+ Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+ TensorOf<[AnyType]>:$out_values)>,
+ Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+ TensorOf<[AnyType]>:$ret_values,
+ Variadic<AnyIndexingScalarLike>:$lvl_lens,
+ AnyIndexingScalarLike:$val_len)> {
let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
// |0.0, 0.0, 0.0, 0.0|
%v, %p, %c, %v_len, %p_len, %c_len =
sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
- outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
- -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+ out_lvls(%op, %oi) : tensor<2xindex>, tensor<3x2xindex>,
+ out_vals(%od) : tensor<3xf64> ->
+ tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
// %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
// %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
// %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
let assemblyFormat =
"$tensor `:` type($tensor) "
- "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
- "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+ "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+ "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+ "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+ "`(` type($lvl_lens) `)` `,` type($val_len)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b39a2d9c57d8b0..617ff7d39dcfbd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,12 +33,12 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
// Convert the external representation of the values array.
const SparseTensorType stt(cast<RankedTensorType>(type));
- auto shape = stt.getBatchLvlShape();
- shape.push_back(ShapedType::kDynamic);
- auto vtp = RankedTensorType::get(shape, stt.getElementType());
- convTypes.push_back(vtp);
- if (extraTypes)
- extraTypes->push_back(vtp);
+ // auto shape = stt.getBatchLvlShape();
+ // shape.push_back(ShapedType::kDynamic);
+ // auto vtp = RankedTensorType::get(shape, stt.getElementType());
+ // convTypes.push_back(vtp);
+ // if (extraTypes)
+ // extraTypes->push_back(vtp);
// Convert the external representation of the position/coordinate array.
foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
@@ -46,7 +46,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
convTypes.push_back(rtp);
@@ -78,21 +79,16 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
- // Collect the external representation of the values array for
- // input or the outgoing sparse tensor for output.
- inputs.push_back(fromVals[idx++]);
- if (!isIn) {
- inputs.push_back(extraVals[extra++]);
- retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
- cntTypes.push_back(builder.getIndexType()); // nnz
- }
+ if (!isIn)
+ inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
// Collect the external representations of the pos/crd arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else {
@@ -100,7 +96,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
- cntTypes.push_back(rtp.getElementType());
+ cntTypes.push_back(builder.getIndexType());
}
}
return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cb75f6a0ea8801..8be76cac87f297 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
- rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
- ValueRange{rt, ct});
+ rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+ vt);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index eb45a29fb3894e..44c5d4dbe485bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
sz = desc.getValMemSize(rewriter, loc);
src = desc.getValMemRef();
dst = genToMemref(rewriter, loc, op.getOutValues());
- // Values is the last field in descriptor, but it is the first
- // operand in unpack operation.
- // TODO: maybe change unpack/pack operation instead to be
- // consistent.
- retMem.insert(retMem.begin(), dst);
+
+ retMem.push_back(dst);
Type valLenTp = op.getValLen().getType();
- retLen.insert(retLen.begin(),
- genScalarToTensor(rewriter, loc, sz, valLenTp));
+ retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b0447b2436619e..9a31785f5ce83b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,12 +738,6 @@ class SparseTensorDisassembleConverter
auto stt = getSparseTensorType(op.getTensor());
SmallVector<Value> retVal;
SmallVector<Value> retLen;
- // Get the values buffer first.
- auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
- auto valLenTp = op.getValLen().getType();
- auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
- retVal.push_back(vals);
- retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
// Then get the positions and coordinates buffers.
const Level lvlRank = stt.getLvlRank();
Level trailCOOLen = 0;
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
auto poss =
genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
}
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
auto crds =
genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
- auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(crds);
retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
}
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
// Coordinates, copied over with:
// for (i = 0; i < crdLen; i++)
// buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
- auto buf =
- genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+ auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
args[1] = one;
rewriter.create<memref::StoreOp>(loc, c1, buf, args);
rewriter.setInsertionPointAfter(forOp);
- auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(buf);
retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
}
+ // Get the values buffer last.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
// Converts MemRefs back to Tensors.
assert(retVal.size() + retLen.size() == op.getNumResults());
for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
retVal[i] =
rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
}
+
// Appends the actual memory length used in each buffer returned.
retVal.append(retLen.begin(), retLen.end());
rewriter.replaceOp(op, retVal);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
// CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
// CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
// CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
// CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
// CHECK: }
func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
// -----
// CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
// -----
// CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]#1
// CHECK: return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
// -----
// CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME: %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*5]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
// -----
// CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME: %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME: %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME: %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME: %[[E:.*7]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
diff --git a/...
[truncated]
|
@llvm/pr-subscribers-mlir-gpu Author: Peiming Liu (PeimingLiu) Changes: use a consistent order between [dis]assembleOp and storage layout. Patch is 44.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84079.diff 14 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3a5447d29f866d..feed15d6af0544 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
}
def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
- Arguments<(ins TensorOf<[AnyType]>:$values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+ Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+ TensorOf<[AnyType]>:$values)>,
Results<(outs AnySparseTensor: $result)> {
let summary = "Returns a sparse tensor assembled from the given values and levels";
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
}];
let assemblyFormat =
- "$values `,` $levels attr-dict"
- "`:` type($values) `,` type($levels) `to` type($result)";
+ "` ` `(` $levels `)` `,` $values attr-dict"
+ " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
Arguments<(ins AnySparseTensor:$tensor,
- TensorOf<[AnyType]>:$out_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
- Results<(outs TensorOf<[AnyType]>:$ret_values,
- Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
- AnyIndexingScalarLike:$val_len,
- Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+ Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+ TensorOf<[AnyType]>:$out_values)>,
+ Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+ TensorOf<[AnyType]>:$ret_values,
+ Variadic<AnyIndexingScalarLike>:$lvl_lens,
+ AnyIndexingScalarLike:$val_len)> {
let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
// |0.0, 0.0, 0.0, 0.0|
%v, %p, %c, %v_len, %p_len, %c_len =
sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
- outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
- -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+ out_lvls(%op, %oi) : tensor<2xindex>, tensor<3x2xindex>,
+ out_vals(%od) : tensor<3xf64> ->
+ tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
// %v = arith.constant dense<[ 1.1, 2.2, 3.3 ]> : tensor<3xf64>
// %p = arith.constant dense<[ 0, 3 ]> : tensor<2xindex>
// %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
let assemblyFormat =
"$tensor `:` type($tensor) "
- "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
- "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+ "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+ "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+ "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+ "`(` type($lvl_lens) `)` `,` type($val_len)";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b39a2d9c57d8b0..617ff7d39dcfbd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,12 +33,12 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
// Convert the external representation of the values array.
const SparseTensorType stt(cast<RankedTensorType>(type));
- auto shape = stt.getBatchLvlShape();
- shape.push_back(ShapedType::kDynamic);
- auto vtp = RankedTensorType::get(shape, stt.getElementType());
- convTypes.push_back(vtp);
- if (extraTypes)
- extraTypes->push_back(vtp);
+ // auto shape = stt.getBatchLvlShape();
+ // shape.push_back(ShapedType::kDynamic);
+ // auto vtp = RankedTensorType::get(shape, stt.getElementType());
+ // convTypes.push_back(vtp);
+ // if (extraTypes)
+ // extraTypes->push_back(vtp);
// Convert the external representation of the position/coordinate array.
foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
@@ -46,7 +46,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
ShapedType st = t.cast<ShapedType>();
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
convTypes.push_back(rtp);
@@ -78,21 +79,16 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
- // Collect the external representation of the values array for
- // input or the outgoing sparse tensor for output.
- inputs.push_back(fromVals[idx++]);
- if (!isIn) {
- inputs.push_back(extraVals[extra++]);
- retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
- cntTypes.push_back(builder.getIndexType()); // nnz
- }
+ if (!isIn)
+ inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
// Collect the external representations of the pos/crd arrays.
foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
SparseTensorFieldKind kind,
Level, LevelType) {
if (kind == SparseTensorFieldKind::CrdMemRef ||
- kind == SparseTensorFieldKind::PosMemRef) {
+ kind == SparseTensorFieldKind::PosMemRef ||
+ kind == SparseTensorFieldKind::ValMemRef) {
if (isIn) {
inputs.push_back(fromVals[idx++]);
} else {
@@ -100,7 +96,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
inputs.push_back(extraVals[extra++]);
retTypes.push_back(rtp);
- cntTypes.push_back(rtp.getElementType());
+ cntTypes.push_back(builder.getIndexType());
}
}
return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cb75f6a0ea8801..8be76cac87f297 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
- rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
- ValueRange{rt, ct});
+ rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+ vt);
return success();
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index eb45a29fb3894e..44c5d4dbe485bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
sz = desc.getValMemSize(rewriter, loc);
src = desc.getValMemRef();
dst = genToMemref(rewriter, loc, op.getOutValues());
- // Values is the last field in descriptor, but it is the first
- // operand in unpack operation.
- // TODO: maybe change unpack/pack operation instead to be
- // consistent.
- retMem.insert(retMem.begin(), dst);
+
+ retMem.push_back(dst);
Type valLenTp = op.getValLen().getType();
- retLen.insert(retLen.begin(),
- genScalarToTensor(rewriter, loc, sz, valLenTp));
+ retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
} else {
assert(fKind == SparseTensorFieldKind::PosMemRef ||
fKind == SparseTensorFieldKind::CrdMemRef);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b0447b2436619e..9a31785f5ce83b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,12 +738,6 @@ class SparseTensorDisassembleConverter
auto stt = getSparseTensorType(op.getTensor());
SmallVector<Value> retVal;
SmallVector<Value> retLen;
- // Get the values buffer first.
- auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
- auto valLenTp = op.getValLen().getType();
- auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
- retVal.push_back(vals);
- retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
// Then get the positions and coordinates buffers.
const Level lvlRank = stt.getLvlRank();
Level trailCOOLen = 0;
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
auto poss =
genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
}
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
auto crds =
genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
- auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(crds);
retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
}
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
- auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(poss);
retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
// Coordinates, copied over with:
// for (i = 0; i < crdLen; i++)
// buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
- auto buf =
- genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+ auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
cooStartLvl);
auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
args[1] = one;
rewriter.create<memref::StoreOp>(loc, c1, buf, args);
rewriter.setInsertionPointAfter(forOp);
- auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+ auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
retVal.push_back(buf);
retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
}
+ // Get the values buffer last.
+ auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+ auto valLenTp = op.getValLen().getType();
+ auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+ retVal.push_back(vals);
+ retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
// Converts MemRefs back to Tensors.
assert(retVal.size() + retLen.size() == op.getNumResults());
for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
retVal[i] =
rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
}
+
// Appends the actual memory length used in each buffer returned.
retVal.append(retLen.begin(), retLen.end());
rewriter.replaceOp(op, retVal);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
// CHECK: %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
// CHECK: %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
// CHECK: %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK: %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
// CHECK: return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
// CHECK: }
func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in(%[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
// -----
// CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME: %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
// CHECK: return %[[F]] : tensor<64x64xf32>
// CHECK: }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
// -----
// CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]] = call @_internal_sparse_out(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
// -----
// CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME: %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME: %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME: %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>)
// CHECK: %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
// CHECK: sparse_tensor.disassemble %[[F]]#1
// CHECK: return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
// -----
// CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME: %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*5]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
// -----
// CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME: %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME: %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME: %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME: %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME: %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME: %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME: %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME: %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK: %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME: %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME: %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME: %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME: %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME: %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME: %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME: %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME: %[[E:.*7]]: tensor<?xf32>)
+// CHECK: %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
// CHECK: %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
// CHECK: sparse_tensor.disassemble %[[F]]
// CHECK: return
diff --git a/...
[truncated]
|
…rage layout.