Commit 65ee8f1

[mlir][sparse] fold explicit value during sparsification (#90530)
This ensures the explicit value is generated directly (and not loaded from the values array). Note that not storing the values array at all is still TBD; this is just the very first step.
Parent: 9a1386e
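
In IR terms, the fold replaces a per-entry load from the sparse operand's values array with the constant named by the encoding's explicitVal attribute. A minimal before/after sketch, with hypothetical SSA names (%b_values, %jj are illustrative), assuming explicitVal = 1.0 : f32 as in the new test:

  // Before: each use of a stored entry loads from the values array.
  %v = memref.load %b_values[%jj] : memref<?xf32>
  // After: the load is folded away into the encoding's explicit value.
  %v = arith.constant 1.000000e+00 : f32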

4 files changed (+99, −4 lines)

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 6 additions & 2 deletions
@@ -344,10 +344,14 @@ class SparseTensorType {
   unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }

   /// Returns the explicit value, defaulting to null Attribute for unset.
-  Attribute getExplicitVal() const { return enc.getExplicitVal(); }
+  Attribute getExplicitVal() const {
+    return enc ? enc.getExplicitVal() : nullptr;
+  }

   /// Returns the implicit value, defaulting to null Attribute for 0.
-  Attribute getImplicitVal() const { return enc.getImplicitVal(); }
+  Attribute getImplicitVal() const {
+    return enc ? enc.getImplicitVal() : nullptr;
+  }

   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
   Type getCrdType() const { return enc.getCrdElemType(); }
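
The null check matters because SparseTensorType also wraps plain dense tensors, whose enc is null, so the old getter would invoke getExplicitVal() on a null encoding. An illustrative pairing of the two cases (the encoding is copied from the new test):

  // enc is set: getExplicitVal() returns the attribute 1.0 : f32.
  #CSR_ones_fp = #sparse_tensor.encoding<{
    map = (d0, d1) -> (d0 : dense, d1 : compressed),
    explicitVal = 1.0 : f32,
    implicitVal = 0.0 : f32
  }>
  // enc is null: for tensor<10x20xf32> (no encoding), getExplicitVal()
  // now returns a null Attribute instead of dereferencing null.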

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 8 additions & 2 deletions
@@ -498,9 +498,15 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   Value val = env.exp(exp).val;
   if (val)
     return val;
-  // Load during insertion.
+  // Get tensor operand.
   linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
+  // Fold binary-valued tensor into explicit value.
+  const auto stt = getSparseTensorType(t->get());
+  if (auto explVal = stt.getExplicitVal())
+    return genValFromAttr(builder, loc, explVal);
+  // Load during insertion.
   if (env.isSparseOutput(t)) {
     if (env.isCustomReduc())
       return genInsertionLoadReduce(env, builder, t);
@@ -509,7 +515,7 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
   // Actual load.
   SmallVector<Value> args;
   Value ptr = genSubscript(env, builder, t, args);
-  return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
+  return builder.create<memref::LoadOp>(loc, ptr, args);
 }

 /// Generates a store on a dense or sparse tensor.
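
Because genTensorLoad now returns a constant for explicit-valued operands, later arithmetic can fold further, as the CHECK patterns in the new test confirm. A hedged sketch of the matmul update when the B operand is all ones (illustrative SSA names), which is why the test checks for an add with no multiply:

  // With the B entry folded to the constant 1.0, the update
  //   %m = arith.mulf %a_ik, %cst_one : f32
  //   %s = arith.addf %acc, %m : f32
  // simplifies to a plain accumulation of the A entry:
  %s = arith.addf %acc, %a_ik : f32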

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h

Lines changed: 10 additions & 0 deletions
@@ -399,6 +399,16 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
   return constantI64(builder, loc, static_cast<uint64_t>(lt));
 }

+// Generates a constant from a validated value carrying attribute.
+inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
+  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+    Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
+    return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+  }
+  return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
+}
+
 // TODO: is this at the right place?
 inline bool isZeroRankedTensorOrScalar(Type type) {
   auto rtp = dyn_cast<RankedTensorType>(type);
   return !rtp || rtp.getRank() == 0;
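
genValFromAttr covers the two attribute shapes explicitVal can carry: a scalar TypedAttr becomes an arith.constant, while an ArrayAttr of real/imaginary parts becomes a complex.constant. Illustrative IR for each branch (note that the complex path is still marked TODO in the test below):

  // Scalar attribute 1.0 : f32:
  %c0 = arith.constant 1.000000e+00 : f32
  // Array attribute [1.0 : f32, 0.0 : f32] for a complex value:
  %c1 = complex.constant [1.000000e+00 : f32, 0.000000e+00 : f32] : complex<f32>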

mlir/test/Dialect/SparseTensor/… (new test file)

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification-and-bufferization | FileCheck %s
+
+#CSR_ones_complex = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed)
+  // explicitVal = (1.0, 0.0) : complex<f32>,
+  // implicitVal = (0.0, 0.0) : complex<f32>
+}>
+
+#CSR_ones_fp = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1.0 : f32,
+  implicitVal = 0.0 : f32
+}>
+
+#CSR_ones_int = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = 1 : i32,
+  implicitVal = 0 : i32
+}>
+
+// CHECK-LABEL: func.func @matmul_complex
+//
+// TODO: make this work
+//
+func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
+                          %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
+                          %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xcomplex<f32>>, tensor<20x30xcomplex<f32>,#CSR_ones_complex>)
+    outs(%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>>
+  return %0 : tensor<10x30xcomplex<f32>>
+}
+
+// CHECK-LABEL: func.func @matmul_fp
+// CHECK:       scf.for
+// CHECK:         scf.for
+// CHECK:           %[[X:.*]] = memref.load
+// CHECK:           scf.for
+// CHECK:             %[[I:.*]] = memref.load
+// CHECK:             %[[Y:.*]] = memref.load
+// CHECK:             %[[M:.*]] = arith.addf %[[Y]], %[[X]] : f32
+// CHECK:             memref.store %[[M]]
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+func.func @matmul_fp(%a: tensor<10x20xf32>,
+                     %b: tensor<20x30xf32, #CSR_ones_fp>,
+                     %c: tensor<10x30xf32>) -> tensor<10x30xf32> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xf32>, tensor<20x30xf32,#CSR_ones_fp>)
+    outs(%c: tensor<10x30xf32>) -> tensor<10x30xf32>
+  return %0 : tensor<10x30xf32>
+}
+
+// CHECK-LABEL: func.func @matmul_int
+// CHECK:       scf.for
+// CHECK:         scf.for
+// CHECK:           %[[X:.*]] = memref.load
+// CHECK:           scf.for
+// CHECK:             %[[I:.*]] = memref.load
+// CHECK:             %[[Y:.*]] = memref.load
+// CHECK:             %[[M:.*]] = arith.addi %[[Y]], %[[X]] : i32
+// CHECK:             memref.store %[[M]]
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+func.func @matmul_int(%a: tensor<10x20xi32>,
+                      %b: tensor<20x30xi32, #CSR_ones_int>,
+                      %c: tensor<10x30xi32>) -> tensor<10x30xi32> {
+  %0 = linalg.matmul
+    ins(%a, %b: tensor<10x20xi32>, tensor<20x30xi32,#CSR_ones_int>)
+    outs(%c: tensor<10x30xi32>) -> tensor<10x30xi32>
+  return %0 : tensor<10x30xi32>
+}
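
Note that the CHECK patterns for @matmul_fp and @matmul_int contain no arith.mulf or arith.muli: the load of the B entry folds to the constant 1, multiplication by 1 then folds away, and only the accumulation remains. The complex variant keeps its explicitVal/implicitVal lines commented out (hence the TODO) because that path is not supported yet.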
