Skip to content

Commit e71eacc

Browse files
[mlir][sparse] Support explicit/implicit value for complex type (#90771)
1 parent 0708500 commit e71eacc

File tree

6 files changed

+42
-9
lines changed

6 files changed

+42
-9
lines changed

mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
4545

4646
LINK_LIBS PUBLIC
4747
MLIRArithDialect
48+
MLIRComplexDialect
4849
MLIRDialect
4950
MLIRDialectUtils
5051
MLIRIR

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "mlir/Dialect/Arith/IR/Arith.h"
1919
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
20+
#include "mlir/Dialect/Complex/IR/Complex.h"
2021
#include "mlir/Dialect/Utils/StaticValueUtils.h"
2122
#include "mlir/IR/Builders.h"
2223
#include "mlir/IR/DialectImplementation.h"
@@ -663,6 +664,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
663664
explicitVal = result;
664665
} else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
665666
explicitVal = result;
667+
} else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
668+
explicitVal = result;
666669
} else {
667670
parser.emitError(parser.getNameLoc(),
668671
"expected a numeric value for explicitVal");
@@ -678,6 +681,8 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
678681
implicitVal = result;
679682
} else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
680683
implicitVal = result;
684+
} else if (auto result = llvm::dyn_cast<complex::NumberAttr>(attr)) {
685+
implicitVal = result;
681686
} else {
682687
parser.emitError(parser.getNameLoc(),
683688
"expected a numeric value for implicitVal");

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,12 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
401401

402402
// Generates a constant from a validated value carrying attribute.
403403
inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
404-
if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
405-
Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
406-
return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
404+
if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) {
405+
Type tp = cast<ComplexType>(complexAttr.getType()).getElementType();
406+
return builder.create<complex::ConstantOp>(
407+
loc, complexAttr.getType(),
408+
builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()),
409+
FloatAttr::get(tp, complexAttr.getImag())}));
407410
}
408411
return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
409412
}

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,21 @@ func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
8080

8181
// -----
8282

83+
#CSR_OnlyOnes = #sparse_tensor.encoding<{
84+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
85+
posWidth = 64,
86+
crdWidth = 64,
87+
explicitVal = #complex.number<:f32 1.0, 0.0>,
88+
implicitVal = #complex.number<:f32 0.0, 0.0>
89+
}>
90+
91+
// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = #complex.number<:f32 1.000000e+00, 0.000000e+00> : complex<f32>, implicitVal = #complex.number<:f32 0.000000e+00, 0.000000e+00> : complex<f32> }>
92+
// CHECK-LABEL: func private @sparse_csr(
93+
// CHECK-SAME: tensor<?x?xcomplex<f32>, #[[$CSR_OnlyOnes]]>)
94+
func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)
95+
96+
// -----
97+
8398
#BCSR = #sparse_tensor.encoding<{
8499
map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
85100
}>

mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
// RUN: --sparsification-and-bufferization | FileCheck %s
33

44
#CSR_ones_complex = #sparse_tensor.encoding<{
5-
map = (d0, d1) -> (d0 : dense, d1 : compressed)
6-
// explicitVal = (1.0, 0.0) : complex<f32>,
7-
// implicitVal = (0.0, 0.0) : complex<f32>
5+
map = (d0, d1) -> (d0 : dense, d1 : compressed),
6+
explicitVal = #complex.number<:f32 1.0, 0.0>,
7+
implicitVal = #complex.number<:f32 0.0, 0.0>
88
}>
99

1010
#CSR_ones_fp = #sparse_tensor.encoding<{
@@ -20,9 +20,17 @@
2020
}>
2121

2222
// CHECK-LABEL: func.func @matmul_complex
23-
//
24-
// TODO: make this work
25-
//
23+
// CHECK: scf.for
24+
// CHECK: scf.for
25+
// CHECK: %[[X:.*]] = memref.load
26+
// CHECK: scf.for
27+
// CHECK: %[[I:.*]] = memref.load
28+
// CHECK: %[[Y:.*]] = memref.load
29+
// CHECK: %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32>
30+
// CHECK: memref.store %[[M]]
31+
// CHECK: }
32+
// CHECK: }
33+
// CHECK: }
2634
func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
2735
%b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
2836
%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3066,6 +3066,7 @@ cc_library(
30663066
":ArithDialect",
30673067
":BufferizationInterfaces",
30683068
":BytecodeOpInterface",
3069+
":ComplexDialect",
30693070
":DialectUtils",
30703071
":IR",
30713072
":InferTypeOpInterface",

0 commit comments

Comments
 (0)