Skip to content

Commit 7e83a1a

Browse files
authored
[mlir][sparse] add verification of absent value in sparse_tensor.unary (#70248)
This value should always be a plain constant or something invariant computed outside the surrounding linalg operation, since there is no co-iteration defined on anything done in this branch. Fixes: #69395
1 parent 8958f0d commit 7e83a1a

File tree

3 files changed

+115
-57
lines changed

3 files changed

+115
-57
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -624,11 +624,11 @@ def SparseTensor_InsertOp : SparseTensor_Op<"insert",
624624
string summary = "Inserts a value into the sparse tensor";
625625
string description = [{
626626
Inserts the value into the underlying storage of the tensor at the
627-
given level-coordinates. The arity of `lvlCoords` must match the
628-
level-rank of the tensor. This operation can only be applied when
629-
the tensor materializes unintialized from a `bufferization.alloc_tensor`
630-
operation and the final tensor is constructed with a `load` operation
631-
which has the `hasInserts` attribute set.
627+
given level-coordinates. The arity of `lvlCoords` must match the
628+
level-rank of the tensor. This operation can only be applied when
629+
the tensor materializes uninitialized from a `tensor.empty` operation
630+
and the final tensor is constructed with a `load` operation which
631+
has the `hasInserts` attribute set.
632632

633633
The level-properties of the sparse tensor type fully describe what
634634
kind of insertion order is allowed. When all levels have "unique"
@@ -974,7 +974,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [Pure]>,
974974
Example of isEqual applied to intersecting elements only:
975975

976976
```mlir
977-
%C = bufferization.alloc_tensor...
977+
%C = tensor.empty(...)
978978
%0 = linalg.generic #trait
979979
ins(%A: tensor<?xf64, #SparseVector>,
980980
%B: tensor<?xf64, #SparseVector>)
@@ -996,7 +996,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [Pure]>,
996996
Example of A+B in upper triangle, A-B in lower triangle:
997997

998998
```mlir
999-
%C = bufferization.alloc_tensor...
999+
%C = tensor.empty(...)
10001000
%1 = linalg.generic #trait
10011001
ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64, #CSR>
10021002
outs(%C: tensor<?x?xf64, #CSR> {
@@ -1029,7 +1029,7 @@ def SparseTensor_BinaryOp : SparseTensor_Op<"binary", [Pure]>,
10291029
because we never use its values, only its sparse structure:
10301030

10311031
```mlir
1032-
%C = bufferization.alloc_tensor...
1032+
%C = tensor.empty(...)
10331033
%2 = linalg.generic #trait
10341034
ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xi32, #CSR>
10351035
outs(%C: tensor<?x?xf64, #CSR> {
@@ -1069,7 +1069,9 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [Pure]>,
10691069
Each region contains a single block describing the computation and result.
10701070
A non-empty block must end with a sparse_tensor.yield and the return type
10711071
must match the type of `output`. The primary region's block has one
1072-
argument, while the missing region's block has zero arguments.
1072+
argument, while the missing region's block has zero arguments. The
1073+
absent region may only generate constants or values already computed
1074+
on entry of the `linalg.generic` operation.
10731075

10741076
A region may also be declared empty (i.e. `absent={}`), indicating that the
10751077
region does not contribute to the output.
@@ -1082,17 +1084,17 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [Pure]>,
10821084
Example of A+1, restricted to existing elements:
10831085

10841086
```mlir
1085-
%C = bufferization.alloc_tensor...
1087+
%C = tensor.empty(...) : tensor<?xf64, #SparseVector>
10861088
%0 = linalg.generic #trait
10871089
ins(%A: tensor<?xf64, #SparseVector>)
10881090
outs(%C: tensor<?xf64, #SparseVector>) {
10891091
^bb0(%a: f64, %c: f64) :
10901092
%result = sparse_tensor.unary %a : f64 to f64
10911093
present={
1092-
^bb0(%arg0: f64):
1093-
%cf1 = arith.constant 1.0 : f64
1094-
%ret = arith.addf %arg0, %cf1 : f64
1095-
sparse_tensor.yield %ret : f64
1094+
^bb0(%arg0: f64):
1095+
%cf1 = arith.constant 1.0 : f64
1096+
%ret = arith.addf %arg0, %cf1 : f64
1097+
sparse_tensor.yield %ret : f64
10961098
}
10971099
absent={}
10981100
linalg.yield %result : f64
@@ -1102,41 +1104,42 @@ def SparseTensor_UnaryOp : SparseTensor_Op<"unary", [Pure]>,
11021104
Example returning +1 for existing values and -1 for missing values:
11031105

11041106
```mlir
1105-
%C = bufferization.alloc_tensor...
1107+
%p1 = arith.constant 1 : i32
1108+
%m1 = arith.constant -1 : i32
1109+
%C = tensor.empty(...) : tensor<?xi32, #SparseVector>
11061110
%1 = linalg.generic #trait
11071111
ins(%A: tensor<?xf64, #SparseVector>)
1108-
outs(%C: tensor<?xf64, #SparseVector>) {
1109-
^bb0(%a: f64, %c: f64) :
1112+
outs(%C: tensor<?xi32, #SparseVector>) {
1113+
^bb0(%a: f64, %c: i32) :
11101114
%result = sparse_tensor.unary %a : f64 to i32
11111115
present={
11121116
^bb0(%x: f64):
1113-
%ret = arith.constant 1 : i32
1114-
sparse_tensor.yield %ret : i32
1115-
}
1116-
absent={
1117-
%ret = arith.constant -1 : i32
1118-
sparse_tensor.yield %ret : i32
1119-
}
1120-
linalg.yield %result : f64
1121-
} -> tensor<?xf64, #SparseVector>
1117+
sparse_tensor.yield %p1 : i32
1118+
}
1119+
absent={
1120+
sparse_tensor.yield %m1 : i32
1121+
}
1122+
linalg.yield %result : i32
1123+
} -> tensor<?xi32, #SparseVector>
11221124
```
11231125

11241126
Example showing a structural inversion (existing values become missing in
11251127
the output, while missing values are filled with 1):
11261128

11271129
```mlir
1128-
%C = bufferization.alloc_tensor...
1130+
%c1 = arith.constant 1 : i64
1131+
%C = tensor.empty(...) : tensor<?xi64, #SparseVector>
11291132
%2 = linalg.generic #trait
1130-
ins(%A: tensor<?xf64, #SparseVector>)
1131-
outs(%C: tensor<?xf64, #SparseVector>) {
1132-
%result = sparse_tensor.unary %a : f64 to i64
1133-
present={}
1134-
absent={
1135-
%ret = arith.constant 1 : i64
1136-
sparse_tensor.yield %ret : i64
1137-
}
1138-
linalg.yield %result : f64
1139-
} -> tensor<?xf64, #SparseVector>
1133+
ins(%A: tensor<?xf64, #SparseVector>)
1134+
outs(%C: tensor<?xi64, #SparseVector>) {
1135+
^bb0(%a: f64, %c: i64) :
1136+
%result = sparse_tensor.unary %a : f64 to i64
1137+
present={}
1138+
absent={
1139+
sparse_tensor.yield %c1 : i64
1140+
}
1141+
linalg.yield %result : i64
1142+
} -> tensor<?xi64, #SparseVector>
11401143
```
11411144
}];
11421145

@@ -1177,7 +1180,7 @@ def SparseTensor_ReduceOp : SparseTensor_Op<"reduce", [Pure, SameOperandsAndResu
11771180
```mlir
11781181
%cf1 = arith.constant 1.0 : f64
11791182
%cf100 = arith.constant 100.0 : f64
1180-
%C = bufferization.alloc_tensor...
1183+
%C = tensor.empty(...)
11811184
%0 = linalg.generic #trait
11821185
ins(%A: tensor<?x?xf64, #SparseMatrix>)
11831186
outs(%C: tensor<?xf64, #SparseVector>) {
@@ -1220,7 +1223,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12201223
Example of selecting A >= 4.0:
12211224

12221225
```mlir
1223-
%C = bufferization.alloc_tensor...
1226+
%C = tensor.empty(...)
12241227
%0 = linalg.generic #trait
12251228
ins(%A: tensor<?xf64, #SparseVector>)
12261229
outs(%C: tensor<?xf64, #SparseVector>) {
@@ -1238,7 +1241,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
12381241
Example of selecting lower triangle of a matrix:
12391242

12401243
```mlir
1241-
%C = bufferization.alloc_tensor...
1244+
%C = tensor.empty(...)
12421245
%1 = linalg.generic #trait
12431246
ins(%A: tensor<?x?xf64, #CSR>)
12441247
outs(%C: tensor<?x?xf64, #CSR>) {

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,13 @@
3434
using namespace mlir;
3535
using namespace mlir::sparse_tensor;
3636

37+
#define RETURN_FAILURE_IF_FAILED(X) \
38+
if (failed(X)) { \
39+
return failure(); \
40+
}
41+
3742
//===----------------------------------------------------------------------===//
38-
// Additional convenience methods.
43+
// Local convenience methods.
3944
//===----------------------------------------------------------------------===//
4045

4146
static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -52,7 +57,7 @@ static constexpr bool acceptBitWidth(unsigned bitWidth) {
5257
}
5358

5459
//===----------------------------------------------------------------------===//
55-
// StorageLayout
60+
// SparseTensorDialect StorageLayout.
5661
//===----------------------------------------------------------------------===//
5762

5863
static constexpr Level kInvalidLevel = -1u;
@@ -183,7 +188,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
183188
}
184189

185190
//===----------------------------------------------------------------------===//
186-
// TensorDialect Attribute Methods.
191+
// SparseTensorDialect Attribute Methods.
187192
//===----------------------------------------------------------------------===//
188193

189194
std::optional<uint64_t> SparseTensorDimSliceAttr::getStatic(int64_t v) {
@@ -658,11 +663,6 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
658663
return success();
659664
}
660665

661-
#define RETURN_FAILURE_IF_FAILED(X) \
662-
if (failed(X)) { \
663-
return failure(); \
664-
}
665-
666666
LogicalResult SparseTensorEncodingAttr::verifyEncoding(
667667
ArrayRef<DynSize> dimShape, Type elementType,
668668
function_ref<InFlightDiagnostic()> emitError) const {
@@ -685,7 +685,7 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
685685
}
686686

687687
//===----------------------------------------------------------------------===//
688-
// Convenience Methods.
688+
// Convenience methods.
689689
//===----------------------------------------------------------------------===//
690690

691691
SparseTensorEncodingAttr
@@ -1365,10 +1365,6 @@ LogicalResult SetStorageSpecifierOp::verify() {
13651365
return success();
13661366
}
13671367

1368-
//===----------------------------------------------------------------------===//
1369-
// TensorDialect Linalg.Generic Operations.
1370-
//===----------------------------------------------------------------------===//
1371-
13721368
template <class T>
13731369
static LogicalResult verifyNumBlockArgs(T *op, Region &region,
13741370
const char *regionName,
@@ -1445,6 +1441,18 @@ LogicalResult UnaryOp::verify() {
14451441
if (!absent.empty()) {
14461442
RETURN_FAILURE_IF_FAILED(
14471443
verifyNumBlockArgs(this, absent, "absent", TypeRange{}, outputType))
1444+
// Absent branch can only yield invariant values.
1445+
Block *absentBlock = &absent.front();
1446+
Block *parent = getOperation()->getBlock();
1447+
Value absentVal = cast<YieldOp>(absentBlock->getTerminator()).getResult();
1448+
if (auto arg = dyn_cast<BlockArgument>(absentVal)) {
1449+
if (arg.getOwner() == parent)
1450+
return emitError("absent region cannot yield linalg argument");
1451+
} else if (Operation *def = absentVal.getDefiningOp()) {
1452+
if (!isa<arith::ConstantOp>(def) &&
1453+
(def->getBlock() == absentBlock || def->getBlock() == parent))
1454+
return emitError("absent region cannot yield locally computed value");
1455+
}
14481456
}
14491457
return success();
14501458
}
@@ -1719,10 +1727,6 @@ LogicalResult YieldOp::verify() {
17191727

17201728
#undef RETURN_FAILURE_IF_FAILED
17211729

1722-
//===----------------------------------------------------------------------===//
1723-
// TensorDialect Methods.
1724-
//===----------------------------------------------------------------------===//
1725-
17261730
/// Materialize a single constant operation from a given attribute value with
17271731
/// the desired resultant type.
17281732
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,57 @@ func.func @invalid_unary_wrong_yield(%arg0: f64) -> f64 {
544544

545545
// -----
546546

547+
548+
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
549+
550+
#trait = {
551+
indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
552+
iterator_types = ["parallel"]
553+
}
554+
555+
func.func @invalid_absent_value(%arg0 : tensor<100xf64, #SparseVector>) -> tensor<100xf64, #SparseVector> {
556+
%C = tensor.empty() : tensor<100xf64, #SparseVector>
557+
%0 = linalg.generic #trait
558+
ins(%arg0: tensor<100xf64, #SparseVector>)
559+
outs(%C: tensor<100xf64, #SparseVector>) {
560+
^bb0(%a: f64, %c: f64) :
561+
// expected-error@+1 {{absent region cannot yield linalg argument}}
562+
%result = sparse_tensor.unary %a : f64 to f64
563+
present={}
564+
absent={ sparse_tensor.yield %a : f64 }
565+
linalg.yield %result : f64
566+
} -> tensor<100xf64, #SparseVector>
567+
return %0 : tensor<100xf64, #SparseVector>
568+
}
569+
570+
// -----
571+
572+
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
573+
574+
#trait = {
575+
indexing_maps = [ affine_map<(i) -> (i)>, affine_map<(i) -> (i)> ],
576+
iterator_types = ["parallel"]
577+
}
578+
579+
func.func @invalid_absent_computation(%arg0 : tensor<100xf64, #SparseVector>) -> tensor<100xf64, #SparseVector> {
580+
%f0 = arith.constant 0.0 : f64
581+
%C = tensor.empty() : tensor<100xf64, #SparseVector>
582+
%0 = linalg.generic #trait
583+
ins(%arg0: tensor<100xf64, #SparseVector>)
584+
outs(%C: tensor<100xf64, #SparseVector>) {
585+
^bb0(%a: f64, %c: f64) :
586+
%v = arith.addf %a, %f0 : f64
587+
// expected-error@+1 {{absent region cannot yield locally computed value}}
588+
%result = sparse_tensor.unary %a : f64 to f64
589+
present={}
590+
absent={ sparse_tensor.yield %v : f64 }
591+
linalg.yield %result : f64
592+
} -> tensor<100xf64, #SparseVector>
593+
return %0 : tensor<100xf64, #SparseVector>
594+
}
595+
596+
// -----
597+
547598
func.func @invalid_reduce_num_args_mismatch(%arg0: f64, %arg1: f64) -> f64 {
548599
%cf1 = arith.constant 1.0 : f64
549600
// expected-error@+1 {{reduce region must have exactly 2 arguments}}

0 commit comments

Comments
 (0)