
[mlir][sparse] fold explicit value during sparsification #90530


Merged: 3 commits merged into llvm:main from bik on Apr 30, 2024
Conversation

aartbik (Contributor) commented Apr 29, 2024

This ensures that the explicit value is generated (and not a load from the values array). Note that not storing the values array at all is still TBD; this is just the very first step.

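To illustrate the intended effect, here is a minimal before/after sketch (hypothetical IR; the value names are illustrative, assuming a sparse operand whose encoding declares explicitVal = 1.0 : f32, as in the test added below):

  // Before: every stored entry is read back from the values array.
  %v = memref.load %values[%i] : memref<?xf32>
  %m = arith.mulf %a, %v : f32

  // After: the encoding's known explicit value is materialized directly,
  // so downstream folding can remove the multiplication by 1.0 entirely.
  %one = arith.constant 1.0 : f32
  %m = arith.mulf %a, %one : f32

This is also why the CHECK patterns in the new test expect only an addition in the innermost loop: the multiplication by the constant folds away.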
llvmbot (Member) commented Apr 29, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

This ensures that the explicit value is generated (and not a load from the values array). Note that not storing the values array at all is still TBD; this is just the very first step.


Full diff: https://github.com/llvm/llvm-project/pull/90530.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+10-2)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+10)
  • (added) mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir (+75)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0a9bb40b458d68..b04ca11f714ba1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -498,9 +498,17 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   Value val = env.exp(exp).val;
   if (val)
     return val;
-  // Load during insertion.
+  // Get tensor operand.
   linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
+  // Fold binary-valued tensor into explicit value.
+  const auto stt = getSparseTensorType(t->get());
+  if (stt.hasEncoding()) {
+    if (auto explVal = stt.getExplicitVal())
+      return genValFromAttr(builder, loc, explVal);
+  }
+  // Load during insertion.
   if (env.isSparseOutput(t)) {
     if (env.isCustomReduc())
       return genInsertionLoadReduce(env, builder, t);
@@ -509,7 +517,7 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   // Actual load.
   SmallVector<Value> args;
   Value ptr = genSubscript(env, builder, t, args);
-  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
+  return builder.create<memref::LoadOp>(loc, ptr, args);
 }
 
 /// Generates a store on a dense or sparse tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index ce5831d999e9a4..cf3c35f5fa4c78 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -399,6 +399,16 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
   return constantI64(builder, loc, static_cast<uint64_t>(lt));
 }
 
+// Generates a constant from a validated value carrying attribute.
+inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
+    return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+  }
+  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
+}
+
+// TODO: is this at the right place?
 inline bool isZeroRankedTensorOrScalar(Type type) {
   auto rtp = dyn_cast<RankedTensorType>(type);
   return !rtp || rtp.getRank() == 0;
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
new file mode 100755
index 00000000000000..09ec43b393d52d
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification-and-bufferization | FileCheck %s
+
+#CSR_ones_complex = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+// explicitVal = (1.0, 0.0) : complex<f32>,
+// implicitVal = (1.0, 0.0) : complex<f32>
+}>
+
+#CSR_ones_fp = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1.0 : f32,
+  implicitVal = 0.0 : f32
+}>
+
+#CSR_ones_int = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1 : i32,
+  implicitVal = 0 : i32
+}>
+
+// CHECK-LABEL:   func.func @matmul_complex
+//
+// TODO: make this work
+//
+func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
+                          %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
+                          %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xcomplex<f32>>, tensor<20x30xcomplex<f32>,#CSR_ones_complex>)
+    outs(%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>>
+  return %0 : tensor<10x30xcomplex<f32>>
+}
+
+// CHECK-LABEL:   func.func @matmul_fp
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             %[[X:.*]] = memref.load
+// CHECK:             scf.for
+// CHECK:               %[[I:.*]] = memref.load
+// CHECK:               %[[Y:.*]] = memref.load
+// CHECK:               %[[M:.*]] = arith.addf %[[Y]], %[[X]] : f32
+// CHECK:               memref.store %[[M]]
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+func.func @matmul_fp(%a: tensor<10x20xf32>,
+                     %b: tensor<20x30xf32, #CSR_ones_fp>,
+                     %c: tensor<10x30xf32>) -> tensor<10x30xf32> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#CSR_ones_fp>)
+    outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32>
+  return %0 : tensor<10x30xf32>
+}
+
+// CHECK-LABEL:   func.func @matmul_int
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             %[[X:.*]] = memref.load
+// CHECK:             scf.for
+// CHECK:               %[[I:.*]] = memref.load
+// CHECK:               %[[Y:.*]] = memref.load
+// CHECK:               %[[M:.*]] = arith.addi %[[Y]], %[[X]] : i32
+// CHECK:               memref.store %[[M]]
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
+func.func @matmul_int(%a: tensor<10x20xi32>,
+                      %b: tensor<20x30xi32, #CSR_ones_int>,
+                      %c: tensor<10x30xi32>) -> tensor<10x30xi32> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xi32>, tensor<20x30xi32,#CSR_ones_int>)
+    outs(%c: tensor<10x30xi32>) -> tensor<10x30xi32>
+  return %0 : tensor<10x30xi32>
+}

PeimingLiu (Member) left a comment:


neat!

aartbik requested a review from yinying-lisa-li on April 29, 2024 23:06
aartbik merged commit 65ee8f1 into llvm:main on Apr 30, 2024
aartbik deleted the bik branch on April 30, 2024 01:06
Labels: mlir:sparse (Sparse compiler in MLIR), mlir