[mlir][sparse] use a consistent order between [dis]assembleOp and storage layout #84079


Merged
2 commits merged on Mar 6, 2024

Conversation

PeimingLiu (Member)

[mlir][sparse] use a consistent order between [dis]assembleOp and storage layout.

llvmbot (Member) commented Mar 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sparse

Author: Peiming Liu (PeimingLiu)

Changes

[mlir][sparse] use a consistent order between [dis]assembleOp and storage layout.


Patch is 44.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84079.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+17-14)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+13-17)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+3-7)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp (+13-12)
  • (modified) mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir (+1-1)
  • (modified) mlir/test/Dialect/SparseTensor/external.mlir (+33-33)
  • (modified) mlir/test/Dialect/SparseTensor/invalid.mlir (+17-14)
  • (modified) mlir/test/Dialect/SparseTensor/pack_copy.mlir (+18-20)
  • (modified) mlir/test/Dialect/SparseTensor/roundtrip.mlir (+13-12)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_pack.mlir (+9-8)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir (+8-7)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack.mlir (+17-14)
  • (modified) mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_pack_d.mlir (+6-7)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 3a5447d29f866d..feed15d6af0544 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -55,8 +55,8 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [Pure]>,
 }
 
 def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
-    Arguments<(ins TensorOf<[AnyType]>:$values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels)>,
+    Arguments<(ins Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$levels,
+                   TensorOf<[AnyType]>:$values)>,
     Results<(outs AnySparseTensor: $result)> {
   let summary = "Returns a sparse tensor assembled from the given values and levels";
 
@@ -96,20 +96,20 @@ def SparseTensor_AssembleOp : SparseTensor_Op<"assemble", [Pure]>,
   }];
 
   let assemblyFormat =
-    "$values `,` $levels attr-dict"
-    "`:` type($values) `,` type($levels) `to` type($result)";
+    "` ` `(` $levels `)` `,` $values attr-dict"
+    " `:` `(` type($levels) `)` `,` type($values) `to` type($result)";
 
   let hasVerifier = 1;
 }
 
 def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVariadicResultSize]>,
     Arguments<(ins AnySparseTensor:$tensor,
-                   TensorOf<[AnyType]>:$out_values,
-                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels)>,
-    Results<(outs TensorOf<[AnyType]>:$ret_values,
-                  Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
-                  AnyIndexingScalarLike:$val_len,
-                  Variadic<AnyIndexingScalarLike>:$lvl_lens)> {
+                   Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$out_levels,
+                   TensorOf<[AnyType]>:$out_values)>,
+    Results<(outs Variadic<TensorOf<[AnySignlessIntegerOrIndex]>>:$ret_levels,
+                  TensorOf<[AnyType]>:$ret_values,
+                  Variadic<AnyIndexingScalarLike>:$lvl_lens,
+                  AnyIndexingScalarLike:$val_len)> {
   let summary = "Returns the (values, coordinates) pair disassembled from the input tensor";
 
   let description = [{
@@ -134,8 +134,9 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
     //                  |0.0, 0.0, 0.0, 0.0|
     %v, %p, %c, %v_len, %p_len, %c_len =
         sparse_tensor.disassemble %sp : tensor<3x4xf64, #COO>
-          outs(%od, %op, %oi : tensor<3xf64>, tensor<2xindex>, tensor<3x2xindex>)
-                            -> tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
+          out_lvls(%op, %oi) : tensor<2xindex>, tensor<3x2xindex>,
+          out_vals(%od) : tensor<3xf64> ->
+          tensor<3xf64>, (tensor<2xindex>, tensor<3x2xindex>), index, (index, index)
     // %v = arith.constant dense<[ 1.1,   2.2,   3.3 ]> : tensor<3xf64>
     // %p = arith.constant dense<[ 0,              3 ]> : tensor<2xindex>
     // %c = arith.constant dense<[[0,0], [1,2], [1,3]]> : tensor<3x2xindex>
@@ -147,8 +148,10 @@ def SparseTensor_DisassembleOp : SparseTensor_Op<"disassemble", [Pure, SameVaria
 
   let assemblyFormat =
     "$tensor `:` type($tensor) "
-    "`outs` `(` $out_values `,` $out_levels `:` type($out_values) `,` type($out_levels) `)` attr-dict"
-    "`->` type($ret_values) `,` `(` type($ret_levels) `)` `,` type($val_len) `,` `(` type($lvl_lens) `)`";
+    "`out_lvls` `(` $out_levels `:` type($out_levels) `)` "
+    "`out_vals` `(` $out_values `:` type($out_values) `)` attr-dict"
+    "`->` `(` type($ret_levels) `)` `,` type($ret_values) `,` "
+    "`(` type($lvl_lens) `)` `,` type($val_len)";
 
   let hasVerifier = 1;
 }
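The reordered assembly format puts the level buffers first, mirroring the underlying storage layout. As a hedged illustration (the tensor shapes, SSA names, and `#CSR` encoding below are hypothetical, not taken from the patch), an assemble op under the new format reads:

```mlir
// Levels (positions, coordinates) are grouped in parentheses and come
// first; the values tensor comes last, matching the storage layout.
%csr = sparse_tensor.assemble (%pos, %crd), %vals
    : (tensor<3xindex>, tensor<4xindex>), tensor<4xf64> to tensor<2x3xf64, #CSR>
```

The same ordering applies to `disassemble`, whose results now come back as `(levels...), values, (lvl_lens...), val_len`.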
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index b39a2d9c57d8b0..617ff7d39dcfbd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,12 +33,12 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
     }
     // Convert the external representation of the values array.
     const SparseTensorType stt(cast<RankedTensorType>(type));
-    auto shape = stt.getBatchLvlShape();
-    shape.push_back(ShapedType::kDynamic);
-    auto vtp = RankedTensorType::get(shape, stt.getElementType());
-    convTypes.push_back(vtp);
-    if (extraTypes)
-      extraTypes->push_back(vtp);
+    //    auto shape = stt.getBatchLvlShape();
+    //    shape.push_back(ShapedType::kDynamic);
+    //    auto vtp = RankedTensorType::get(shape, stt.getElementType());
+    //    convTypes.push_back(vtp);
+    //    if (extraTypes)
+    //      extraTypes->push_back(vtp);
 
     // Convert the external representation of the position/coordinate array.
     foreachFieldAndTypeInSparseTensor(stt, [&convTypes, extraTypes](
@@ -46,7 +46,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
                                                SparseTensorFieldKind kind,
                                                Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         ShapedType st = t.cast<ShapedType>();
         auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
         convTypes.push_back(rtp);
@@ -78,21 +79,16 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
     SmallVector<Value> inputs;
     SmallVector<Type> retTypes;
     SmallVector<Type> cntTypes;
-    // Collect the external representation of the values array for
-    // input or the outgoing sparse tensor for output.
-    inputs.push_back(fromVals[idx++]);
-    if (!isIn) {
-      inputs.push_back(extraVals[extra++]);
-      retTypes.push_back(RankedTensorType::get(shape, stt.getElementType()));
-      cntTypes.push_back(builder.getIndexType()); // nnz
-    }
+    if (!isIn)
+      inputs.push_back(fromVals[idx++]); // The sparse tensor to disassemble
 
     // Collect the external representations of the pos/crd arrays.
     foreachFieldAndTypeInSparseTensor(stt, [&, isIn](Type t, FieldIndex,
                                                      SparseTensorFieldKind kind,
                                                      Level, LevelType) {
       if (kind == SparseTensorFieldKind::CrdMemRef ||
-          kind == SparseTensorFieldKind::PosMemRef) {
+          kind == SparseTensorFieldKind::PosMemRef ||
+          kind == SparseTensorFieldKind::ValMemRef) {
         if (isIn) {
           inputs.push_back(fromVals[idx++]);
         } else {
@@ -100,7 +96,7 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
           auto rtp = RankedTensorType::get(st.getShape(), st.getElementType());
           inputs.push_back(extraVals[extra++]);
           retTypes.push_back(rtp);
-          cntTypes.push_back(rtp.getElementType());
+          cntTypes.push_back(builder.getIndexType());
         }
       }
       return true;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index cb75f6a0ea8801..8be76cac87f297 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -928,8 +928,8 @@ static LogicalResult rewriteSpGEMM(PatternRewriter &rewriter,
   Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
   Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
   Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
-  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), vt,
-                                          ValueRange{rt, ct});
+  rewriter.replaceOpWithNewOp<AssembleOp>(op, c.getType(), ValueRange{rt, ct},
+                                          vt);
   return success();
 }
 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index eb45a29fb3894e..44c5d4dbe485bf 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1409,14 +1409,10 @@ struct SparseDisassembleOpConverter
         sz = desc.getValMemSize(rewriter, loc);
         src = desc.getValMemRef();
         dst = genToMemref(rewriter, loc, op.getOutValues());
-        // Values is the last field in descriptor, but it is the first
-        // operand in unpack operation.
-        // TODO: maybe change unpack/pack operation instead to be
-        // consistent.
-        retMem.insert(retMem.begin(), dst);
+
+        retMem.push_back(dst);
         Type valLenTp = op.getValLen().getType();
-        retLen.insert(retLen.begin(),
-                      genScalarToTensor(rewriter, loc, sz, valLenTp));
+        retLen.push_back(genScalarToTensor(rewriter, loc, sz, valLenTp));
       } else {
         assert(fKind == SparseTensorFieldKind::PosMemRef ||
                fKind == SparseTensorFieldKind::CrdMemRef);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index b0447b2436619e..9a31785f5ce83b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -738,12 +738,6 @@ class SparseTensorDisassembleConverter
     auto stt = getSparseTensorType(op.getTensor());
     SmallVector<Value> retVal;
     SmallVector<Value> retLen;
-    // Get the values buffer first.
-    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
-    auto valLenTp = op.getValLen().getType();
-    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
-    retVal.push_back(vals);
-    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
     // Then get the positions and coordinates buffers.
     const Level lvlRank = stt.getLvlRank();
     Level trailCOOLen = 0;
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
         auto poss =
             genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(poss);
         retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       }
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
         auto crds =
             genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
-        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(crds);
         retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
       }
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
       auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                    cooStartLvl);
       auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(poss);
       retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       // Coordinates, copied over with:
       //    for (i = 0; i < crdLen; i++)
       //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
-      auto buf =
-          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
       auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                       cooStartLvl);
       auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
       args[1] = one;
       rewriter.create<memref::StoreOp>(loc, c1, buf, args);
       rewriter.setInsertionPointAfter(forOp);
-      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(buf);
       retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
     }
+    // Get the values buffer last.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
     // Converts MemRefs back to Tensors.
     assert(retVal.size() + retLen.size() == op.getNumResults());
     for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
       retVal[i] =
           rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
     }
+
     // Appends the actual memory length used in each buffer returned.
     retVal.append(retLen.begin(), retLen.end());
     rewriter.replaceOp(op, retVal);
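The index arithmetic above follows directly from the new ordering: because the values buffer is no longer pushed into `retLen` first, the type of the next level length lives at `retLen.size()` instead of `retLen.size() - 1`. A minimal standalone sketch of the ordering (hypothetical helper, not an MLIR API call):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Sketch only: models the converter's result ordering after this patch.
// Level buffers (pos/crd) are appended first and the values buffer last,
// so while building, the next level-length slot is simply retLen.size(),
// with no "- 1" offset for a leading values entry.
std::vector<std::string> disassembleResultOrder(int numLevelBufs) {
  std::vector<std::string> retVal;
  for (int l = 0; l < numLevelBufs; ++l)
    retVal.push_back("lvl" + std::to_string(l)); // positions/coordinates
  retVal.push_back("vals");                      // values buffer last
  return retVal;
}
```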
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
 // CHECK:           %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
 // CHECK:           %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
 // CHECK:           %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
 // CHECK:           return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
 // CHECK:         }
 func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME:    %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
 // -----
 
 // CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]#1
 // CHECK:         return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME:    %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*5]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME:    %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME:    %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME:    %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME:    %[[E:.*7]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
diff --git a/...
[truncated]

llvmbot (Member) commented Mar 5, 2024

@llvm/pr-subscribers-mlir-gpu

     // Then get the positions and coordinates buffers.
     const Level lvlRank = stt.getLvlRank();
     Level trailCOOLen = 0;
@@ -761,7 +755,7 @@ class SparseTensorDisassembleConverter
         auto poss =
             genPositionsCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-        auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(poss);
         retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       }
@@ -769,7 +763,7 @@ class SparseTensorDisassembleConverter
         auto crds =
             genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(), l);
         auto crdLen = linalg::createOrFoldDimOp(rewriter, loc, crds, 0);
-        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+        auto crdLenTp = op.getLvlLens().getTypes()[retLen.size()];
         retVal.push_back(crds);
         retLen.push_back(genScalarToTensor(rewriter, loc, crdLen, crdLenTp));
       }
@@ -784,14 +778,13 @@ class SparseTensorDisassembleConverter
       auto poss = genPositionsCall(rewriter, loc, stt, adaptor.getTensor(),
                                    cooStartLvl);
       auto posLen = linalg::createOrFoldDimOp(rewriter, loc, poss, 0);
-      auto posLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto posLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(poss);
       retLen.push_back(genScalarToTensor(rewriter, loc, posLen, posLenTp));
       // Coordinates, copied over with:
       //    for (i = 0; i < crdLen; i++)
       //       buf[i][0] = crd0[i]; buf[i][1] = crd1[i];
-      auto buf =
-          genToMemref(rewriter, loc, op.getOutLevels()[retLen.size() - 1]);
+      auto buf = genToMemref(rewriter, loc, op.getOutLevels()[retLen.size()]);
       auto crds0 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
                                       cooStartLvl);
       auto crds1 = genCoordinatesCall(rewriter, loc, stt, adaptor.getTensor(),
@@ -814,10 +807,17 @@ class SparseTensorDisassembleConverter
       args[1] = one;
       rewriter.create<memref::StoreOp>(loc, c1, buf, args);
       rewriter.setInsertionPointAfter(forOp);
-      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size() - 1];
+      auto bufLenTp = op.getLvlLens().getTypes()[retLen.size()];
       retVal.push_back(buf);
       retLen.push_back(genScalarToTensor(rewriter, loc, bufLen, bufLenTp));
     }
+    // Get the values buffer last.
+    auto vals = genValuesCall(rewriter, loc, stt, adaptor.getTensor());
+    auto valLenTp = op.getValLen().getType();
+    auto valLen = linalg::createOrFoldDimOp(rewriter, loc, vals, 0);
+    retVal.push_back(vals);
+    retLen.push_back(genScalarToTensor(rewriter, loc, valLen, valLenTp));
+
     // Converts MemRefs back to Tensors.
     assert(retVal.size() + retLen.size() == op.getNumResults());
     for (unsigned i = 0, sz = retVal.size(); i < sz; i++) {
@@ -825,6 +825,7 @@ class SparseTensorDisassembleConverter
       retVal[i] =
           rewriter.create<tensor::CastOp>(loc, op.getResultTypes()[i], tensor);
     }
+
     // Appends the actual memory length used in each buffer returned.
     retVal.append(retLen.begin(), retLen.end());
     rewriter.replaceOp(op, retVal);
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
index 7ac37c1c4950c0..fa8ad1cc506048 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_spgemm_lib.mlir
@@ -85,7 +85,7 @@
 // CHECK:           %[[VAL_a2:.*]] = bufferization.to_tensor %[[VAL_83]] : memref<?xf32>
 // CHECK:           %[[VAL_a3:.*]] = bufferization.to_tensor %[[VAL_81]] : memref<?xindex>
 // CHECK:           %[[VAL_a4:.*]] = bufferization.to_tensor %[[VAL_82]] : memref<?xindex>
-// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble %[[VAL_a2]], %[[VAL_a3]], %[[VAL_a4]] : tensor<?xf32>, tensor<?xindex>, tensor<?xindex> to tensor<8x8xf32, #{{.*}}>
+// CHECK:           %[[VAL_a5:.*]] = sparse_tensor.assemble (%[[VAL_a3]], %[[VAL_a4]]), %[[VAL_a2]] : (tensor<?xindex>, tensor<?xindex>), tensor<?xf32> to tensor<8x8xf32, #{{.*}}>
 // CHECK:           return %[[VAL_a5]] : tensor<8x8xf32, #{{.*}}>
 // CHECK:         }
 func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
diff --git a/mlir/test/Dialect/SparseTensor/external.mlir b/mlir/test/Dialect/SparseTensor/external.mlir
index b5701ad2024264..435737fc0979b5 100644
--- a/mlir/test/Dialect/SparseTensor/external.mlir
+++ b/mlir/test/Dialect/SparseTensor/external.mlir
@@ -13,10 +13,10 @@ func.func @nop(%arg0: tensor<100xf32>) -> tensor<100xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in(
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in(%[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -30,11 +30,11 @@ func.func @sparse_in(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_in2(
-// CHECK-SAME:    %[[X:.*]]: tensor<100xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> tensor<64x64xf32> {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[X:.*0]]: tensor<100xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>) -> tensor<64x64xf32> {
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_in2(%[[X]], %[[I]])
 // CHECK:         return %[[F]] : tensor<64x64xf32>
 // CHECK:       }
@@ -48,10 +48,10 @@ func.func @sparse_in2(%arg0: tensor<100xf32>, %arg1: tensor<64x64xf32, #sparse>)
 // -----
 
 // CHECK-LABEL: func.func @sparse_out(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]] = call @_internal_sparse_out(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -66,10 +66,10 @@ func.func @sparse_out(%arg0: tensor<64x64xf32>) -> tensor<64x64xf32, #sparse> {
 // -----
 
 // CHECK-LABEL: func.func @sparse_out2(
-// CHECK-SAME:    %[[X:.*]]: tensor<64x64xf32>,
-// CHECK-SAME:    %[[A:.*]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*]]: tensor<?xindex>) -> (tensor<64x64xf32>, tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
+// CHECK-SAME:    %[[X:.*0]]: tensor<64x64xf32>,
+// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>)
 // CHECK:         %[[F:.*]]:2 = call @_internal_sparse_out2(%[[X]])
 // CHECK:         sparse_tensor.disassemble %[[F]]#1
 // CHECK:         return %[[F]]#0
@@ -84,13 +84,13 @@ func.func @sparse_out2(%arg0: tensor<64x64xf32>) -> (tensor<64x64xf32>, tensor<6
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xf32>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xindex>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*2]]: tensor<?xf32>,
+// CHECK-SAME:    %[[E:.*3]]: tensor<?xindex>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*5]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
@@ -104,15 +104,15 @@ func.func @sparse_inout(%arg0: tensor<64x64xf32, #sparse>) -> tensor<64x64xf32,
 // -----
 
 // CHECK-LABEL: func.func @sparse_inout_coo_soa(
-// CHECK-SAME:    %[[A:.*0]]: tensor<?xf32>,
-// CHECK-SAME:    %[[B:.*1]]: tensor<?xindex>,
-// CHECK-SAME:    %[[C:.*2]]: tensor<?xindex>,
-// CHECK-SAME:    %[[D:.*3]]: tensor<?xindex>,
-// CHECK-SAME:    %[[E:.*4]]: tensor<?xf32>,
-// CHECK-SAME:    %[[F:.*5]]: tensor<?xindex>,
-// CHECK-SAME:    %[[G:.*6]]: tensor<?xindex>,
-// CHECK-SAME:    %[[H:.*7]]: tensor<?xindex>) -> (tensor<?xf32>, tensor<?xindex>, tensor<?xindex>, tensor<?xindex>) {
-// CHECK:         %[[I:.*]] = sparse_tensor.assemble %[[A]], %[[B]], %[[C]], %[[D]]
+// CHECK-SAME:    %[[B:.*0]]: tensor<?xindex>,
+// CHECK-SAME:    %[[C:.*1]]: tensor<?xindex>,
+// CHECK-SAME:    %[[D:.*2]]: tensor<?xindex>,
+// CHECK-SAME:    %[[A:.*3]]: tensor<?xf32>,
+// CHECK-SAME:    %[[F:.*4]]: tensor<?xindex>,
+// CHECK-SAME:    %[[G:.*5]]: tensor<?xindex>,
+// CHECK-SAME:    %[[H:.*6]]: tensor<?xindex>,
+// CHECK-SAME:    %[[E:.*7]]: tensor<?xf32>)
+// CHECK:         %[[I:.*]] = sparse_tensor.assemble (%[[B]], %[[C]], %[[D]]), %[[A]]
 // CHECK:         %[[F:.*]] = call @_internal_sparse_inout_coo_soa(%[[I]])
 // CHECK:         sparse_tensor.disassemble %[[F]]
 // CHECK:         return
diff --git a/...
[truncated]

@PeimingLiu PeimingLiu merged commit fc9f1d4 into llvm:main Mar 6, 2024
@PeimingLiu PeimingLiu deleted the merger-batch branch March 6, 2024 19:11
Labels: mlir:gpu, mlir:sparse (Sparse compiler in MLIR), mlir