Skip to content

Commit b6d0fd0

Browse files
address comments
1 parent 0a84872 commit b6d0fd0

File tree

3 files changed

+56
-40
lines changed

3 files changed

+56
-40
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
167167
- **soa** : only applicable to singleton levels, fuses the singleton
168168
level in SoA (structure of arrays) scheme.
169169

170-
In addition to the map, the following four fields are optional:
170+
In addition to the map, the following fields are optional:
171171

172172
- The required bitwidth for position storage (integral offsets
173173
into the sparse storage scheme). A narrow width reduces the memory
@@ -185,7 +185,10 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
185185

186186
- The explicit value for the sparse tensor. If explicitVal is set,
187187
then all the non-zero values in the tensor have the same explicit value.
188-
The default value Attribute() indicates that it is not set.
188+
The default value Attribute() indicates that it is not set. This
189+
is useful for binary-valued tensors whose values could only
190+
be 0 or 1, as we can set the explicit value to be 1 instead of
191+
storing the values array.
189192

190193
- The implicit value for the sparse tensor. If implicitVal is set,
191194
then the "zero" value in the tensor is equal to the implicit value.

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,61 +22,61 @@ func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
2222

2323
// -----
2424

25-
#CSR = #sparse_tensor.encoding<{
25+
#CSR_OnlyOnes = #sparse_tensor.encoding<{
2626
map = (d0, d1) -> (d0 : dense, d1 : compressed),
2727
posWidth = 64,
2828
crdWidth = 64,
2929
explicitVal = 1.0 : f32,
3030
implicitVal = 0.0 : f32
3131
}>
3232

33-
// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 0.000000e+00 : f32 }>
33+
// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1.000000e+00 : f32, implicitVal = 0.000000e+00 : f32 }>
3434
// CHECK-LABEL: func private @sparse_csr(
35-
// CHECK-SAME: tensor<?x?xf32, #[[$CSR]]>)
36-
func.func private @sparse_csr(tensor<?x?xf32, #CSR>)
35+
// CHECK-SAME: tensor<?x?xf32, #[[$CSR_OnlyOnes]]>)
36+
func.func private @sparse_csr(tensor<?x?xf32, #CSR_OnlyOnes>)
3737

3838
// -----
3939

40-
#CSR = #sparse_tensor.encoding<{
40+
#CSR_OnlyOnes = #sparse_tensor.encoding<{
4141
map = (d0, d1) -> (d0 : dense, d1 : compressed),
4242
explicitVal = 1.0 : f64,
4343
implicitVal = 0.0 : f64
4444
}>
4545

46-
// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 0.000000e+00 : f64 }>
46+
// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), explicitVal = 1.000000e+00 : f64, implicitVal = 0.000000e+00 : f64 }>
4747
// CHECK-LABEL: func private @sparse_csr(
48-
// CHECK-SAME: tensor<?x?xf64, #[[$CSR]]>)
49-
func.func private @sparse_csr(tensor<?x?xf64, #CSR>)
48+
// CHECK-SAME: tensor<?x?xf64, #[[$CSR_OnlyOnes]]>)
49+
func.func private @sparse_csr(tensor<?x?xf64, #CSR_OnlyOnes>)
5050

5151
// -----
5252

53-
#CSR = #sparse_tensor.encoding<{
53+
#CSR_OnlyOnes = #sparse_tensor.encoding<{
5454
map = (d0, d1) -> (d0 : dense, d1 : compressed),
5555
posWidth = 64,
5656
crdWidth = 64,
5757
explicitVal = 1 : i32,
5858
implicitVal = 0 : i32
5959
}>
6060

61-
// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 0 : i32 }>
61+
// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i32, implicitVal = 0 : i32 }>
6262
// CHECK-LABEL: func private @sparse_csr(
63-
// CHECK-SAME: tensor<?x?xi32, #[[$CSR]]>)
64-
func.func private @sparse_csr(tensor<?x?xi32, #CSR>)
63+
// CHECK-SAME: tensor<?x?xi32, #[[$CSR_OnlyOnes]]>)
64+
func.func private @sparse_csr(tensor<?x?xi32, #CSR_OnlyOnes>)
6565

6666
// -----
6767

68-
#CSR = #sparse_tensor.encoding<{
68+
#CSR_OnlyOnes = #sparse_tensor.encoding<{
6969
map = (d0, d1) -> (d0 : dense, d1 : compressed),
7070
posWidth = 64,
7171
crdWidth = 64,
7272
explicitVal = 1 : i64,
7373
implicitVal = 0 : i64
7474
}>
7575

76-
// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }>
76+
// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }>
7777
// CHECK-LABEL: func private @sparse_csr(
78-
// CHECK-SAME: tensor<?x?xi64, #[[$CSR]]>)
79-
func.func private @sparse_csr(tensor<?x?xi64, #CSR>)
78+
// CHECK-SAME: tensor<?x?xi64, #[[$CSR_OnlyOnes]]>)
79+
func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
8080

8181
// -----
8282

mlir/test/python/dialects/sparse_tensor/dialect.py

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from mlir.ir import *
44
from mlir.dialects import sparse_tensor as st
5+
import textwrap
56

67

78
def run(f):
@@ -15,12 +16,15 @@ def run(f):
1516
def testEncodingAttr1D():
1617
with Context() as ctx:
1718
parsed = Attribute.parse(
18-
"#sparse_tensor.encoding<{"
19-
" map = (d0) -> (d0 : compressed),"
20-
" posWidth = 16,"
21-
" crdWidth = 32,"
22-
" explicitVal = 1.0 : f64"
23-
"}>"
19+
textwrap.dedent("""\
20+
#sparse_tensor.encoding<{
21+
map = (d0) -> (d0 : compressed),
22+
posWidth = 16,
23+
crdWidth = 32,
24+
explicitVal = 1.0 : f64
25+
}>\
26+
"""
27+
)
2428
)
2529
# CHECK: #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed), posWidth = 16, crdWidth = 32, explicitVal = 1.000000e+00 : f64 }>
2630
print(parsed)
@@ -65,12 +69,15 @@ def testEncodingAttr1D():
6569
def testEncodingAttrStructure():
6670
with Context() as ctx:
6771
parsed = Attribute.parse(
68-
"#sparse_tensor.encoding<{"
69-
" map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,"
70-
" d1 mod 4 : structured[2, 4]),"
71-
" posWidth = 16,"
72-
" crdWidth = 32"
73-
"}>"
72+
textwrap.dedent("""\
73+
#sparse_tensor.encoding<{
74+
map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense,
75+
d1 mod 4 : structured[2, 4]),
76+
posWidth = 16,
77+
crdWidth = 32,
78+
}>\
79+
"""
80+
)
7481
)
7582
# CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 floordiv 4 : dense, d1 mod 4 : structured[2, 4]), posWidth = 16, crdWidth = 32 }>
7683
print(parsed)
@@ -152,11 +159,14 @@ def testEncodingAttrStructure():
152159
def testEncodingAttr2D():
153160
with Context() as ctx:
154161
parsed = Attribute.parse(
155-
"#sparse_tensor.encoding<{"
156-
" map = (d0, d1) -> (d1 : dense, d0 : compressed),"
157-
" posWidth = 8,"
158-
" crdWidth = 32"
159-
"}>"
162+
textwrap.dedent("""\
163+
#sparse_tensor.encoding<{
164+
map = (d0, d1) -> (d1 : dense, d0 : compressed),
165+
posWidth = 8,
166+
crdWidth = 32,
167+
}>\
168+
"""
169+
)
160170
)
161171
# CHECK: #sparse_tensor.encoding<{ map = (d0, d1) -> (d1 : dense, d0 : compressed), posWidth = 8, crdWidth = 32 }>
162172
print(parsed)
@@ -195,11 +205,14 @@ def testEncodingAttrOnTensorType():
195205
with Context() as ctx, Location.unknown():
196206
encoding = st.EncodingAttr(
197207
Attribute.parse(
198-
"#sparse_tensor.encoding<{"
199-
" map = (d0) -> (d0 : compressed), "
200-
" posWidth = 64,"
201-
" crdWidth = 32"
202-
"}>"
208+
textwrap.dedent("""\
209+
#sparse_tensor.encoding<{
210+
map = (d0) -> (d0 : compressed),
211+
posWidth = 64,
212+
crdWidth = 32,
213+
}>\
214+
"""
215+
)
203216
)
204217
)
205218
tt = RankedTensorType.get((1024,), F32Type.get(), encoding=encoding)

0 commit comments

Comments
 (0)