[mlir][sparse] fold explicit value during sparsification #90530
Conversation
This ensures the explicit value is generated (and not a load from the values array). Note that not storing the values array at all is still TBD; this is just the very first step.
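A hedged sketch of the intended effect at the IR level (the encoding is taken from the test added below; the loop body is illustrative, not verbatim sparsifier output): with explicitVal = 1.0 : f32 in the encoding, the sparsified kernel no longer reads the stored values buffer, it materializes the constant directly, and a multiplication by that constant can then fold away entirely.

#CSR_ones_fp = #sparse_tensor.encoding<{
  map = (d0, d1) -> (d0 : dense, d1 : compressed),
  explicitVal = 1.0 : f32,
  implicitVal = 0.0 : f32
}>

// Before: every stored entry was loaded from the values buffer
// (hypothetical SSA names for illustration):
//   %v = memref.load %values[%jj] : memref<?xf32>
//   %m = arith.mulf %x, %v : f32
// After: the explicit value is generated as a constant, so the mulf by 1.0
// folds away and only the add remains, as the FileCheck lines below expect:
//   %v = arith.constant 1.0 : f32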
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir
Author: Aart Bik (aartbik)
Full diff: https://github.com/llvm/llvm-project/pull/90530.diff
3 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0a9bb40b458d68..b04ca11f714ba1 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -498,9 +498,17 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
Value val = env.exp(exp).val;
if (val)
return val;
- // Load during insertion.
+ // Get tensor operand.
linalg::GenericOp op = env.op();
+ Location loc = op.getLoc();
OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
+ // Fold binary-valued tensor into explicit value.
+ const auto stt = getSparseTensorType(t->get());
+ if (stt.hasEncoding()) {
+ if (auto explVal = stt.getExplicitVal())
+ return genValFromAttr(builder, loc, explVal);
+ }
+ // Load during insertion.
if (env.isSparseOutput(t)) {
if (env.isCustomReduc())
return genInsertionLoadReduce(env, builder, t);
@@ -509,7 +517,7 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
// Actual load.
SmallVector<Value> args;
Value ptr = genSubscript(env, builder, t, args);
- return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
+ return builder.create<memref::LoadOp>(loc, ptr, args);
}
/// Generates a store on a dense or sparse tensor.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index ce5831d999e9a4..cf3c35f5fa4c78 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -399,6 +399,16 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
return constantI64(builder, loc, static_cast<uint64_t>(lt));
}
+// Generates a constant from a validated value carrying attribute.
+inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
+ return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+ }
+ return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
+}
+
+// TODO: is this at the right place?
inline bool isZeroRankedTensorOrScalar(Type type) {
auto rtp = dyn_cast<RankedTensorType>(type);
return !rtp || rtp.getRank() == 0;
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
new file mode 100755
index 00000000000000..09ec43b393d52d
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN: --sparsification-and-bufferization | FileCheck %s
+
+#CSR_ones_complex = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed)
+// explicitVal = (1.0, 0.0) : complex<f32>,
+// implicitVal = (1.0, 0.0) : complex<f32>
+}>
+
+#CSR_ones_fp = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ explicitVal = 1.0 : f32,
+ implicitVal = 0.0 : f32
+}>
+
+#CSR_ones_int = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ explicitVal = 1 : i32,
+ implicitVal = 0 : i32
+}>
+
+// CHECK-LABEL: func.func @matmul_complex
+//
+// TODO: make this work
+//
+func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
+ %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
+ %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
+ %0 = linalg.matmul
+ ins(%a, %b: tensor<10x20xcomplex<f32>>, tensor<20x30xcomplex<f32>,#CSR_ones_complex>)
+ outs(%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>>
+ return %0 : tensor<10x30xcomplex<f32>>
+}
+
+// CHECK-LABEL: func.func @matmul_fp
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[X:.*]] = memref.load
+// CHECK: scf.for
+// CHECK: %[[I:.*]] = memref.load
+// CHECK: %[[Y:.*]] = memref.load
+// CHECK: %[[M:.*]] = arith.addf %[[Y]], %[[X]] : f32
+// CHECK: memref.store %[[M]]
+// CHECK: }
+// CHECK: }
+// CHECK: }
+func.func @matmul_fp(%a: tensor<10x20xf32>,
+ %b: tensor<20x30xf32, #CSR_ones_fp>,
+ %c: tensor<10x30xf32>) -> tensor<10x30xf32> {
+ %0 = linalg.matmul
+ ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#CSR_ones_fp>)
+ outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32>
+ return %0 : tensor<10x30xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_int
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[X:.*]] = memref.load
+// CHECK: scf.for
+// CHECK: %[[I:.*]] = memref.load
+// CHECK: %[[Y:.*]] = memref.load
+// CHECK: %[[M:.*]] = arith.addi %[[Y]], %[[X]] : i32
+// CHECK: memref.store %[[M]]
+// CHECK: }
+// CHECK: }
+// CHECK: }
+func.func @matmul_int(%a: tensor<10x20xi32>,
+ %b: tensor<20x30xi32, #CSR_ones_int>,
+ %c: tensor<10x30xi32>) -> tensor<10x30xi32> {
+ %0 = linalg.matmul
+ ins(%a, %b: tensor<10x20xi32>, tensor<20x30xi32,#CSR_ones_int>)
+ outs(%c: tensor<10x30xi32>) -> tensor<10x30xi32>
+ return %0 : tensor<10x30xi32>
+}
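For reference, a hedged sketch of what the new genValFromAttr helper is expected to emit for the two attribute kinds it handles (constant values and SSA names are illustrative):

// FloatAttr / IntegerAttr: a plain arithmetic constant of the attribute's type.
%one  = arith.constant 1.0 : f32
// ArrayAttr holding (real, imaginary) parts: a complex constant.
%cone = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>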
neat!