Skip to content

Commit aec223d

Browse files
author
Peiming Liu
committed
add verification rules.
1 parent 3343713 commit aec223d

File tree

3 files changed

+147
-15
lines changed

3 files changed

+147
-15
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,12 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
17591759
return I64BitSet(llvm::cast<IntegerAttr>(getCases()[regionIdx])
17601760
.getValue().getZExtValue());
17611761
}
1762+
auto getRegionDefinedSpaces() {
1763+
return llvm::map_range(getCases().getValue(), [](Attribute attr) {
1764+
return I64BitSet(llvm::cast<IntegerAttr>(attr).getValue().getZExtValue());
1765+
});
1766+
}
1767+
17621768
// The block arguments starts with referenced coordinates, follows by
17631769
// user-provided iteration arguments and ends with iterators.
17641770
Block::BlockArgListType getCrds(unsigned regionIdx) {
@@ -1776,12 +1782,11 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
17761782
return getRegion(regionIdx).getArguments()
17771783
.take_back(getRegionDefinedSpace(regionIdx).count());
17781784
}
1785+
ValueRange getYieldedValues(unsigned regionIdx);
17791786
}];
17801787

1781-
// TODO:
1782-
// let hasVerifier = 1;
1783-
// let hasRegionVerifier = 1;
1784-
// let hasCanonicalizer = 1;
1788+
let hasVerifier = 1;
1789+
let hasRegionVerifier = 1;
17851790
let hasCustomAssemblyFormat = 1;
17861791
}
17871792

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,6 +2497,21 @@ static void printInitializationList(OpAsmPrinter &p,
24972497
p << ")";
24982498
}
24992499

2500+
template <typename SparseLoopOp>
2501+
static LogicalResult verifySparseLoopOp(SparseLoopOp op) {
2502+
if (op.getInitArgs().size() != op.getNumResults()) {
2503+
return op.emitOpError(
2504+
"mismatch in number of loop-carried values and defined values");
2505+
}
2506+
if (op.getCrdUsedLvls().max() > op.getSpaceDim())
2507+
return op.emitOpError("required out-of-bound coordinates");
2508+
2509+
return success();
2510+
}
2511+
2512+
LogicalResult IterateOp::verify() { return verifySparseLoopOp(*this); }
2513+
LogicalResult CoIterateOp::verify() { return verifySparseLoopOp(*this); }
2514+
25002515
void IterateOp::print(OpAsmPrinter &p) {
25012516
p << " " << getIterator() << " in " << getIterSpace();
25022517
if (!getCrdUsedLvls().empty()) {
@@ -2515,17 +2530,6 @@ void IterateOp::print(OpAsmPrinter &p) {
25152530
/*printBlockTerminators=*/!getInitArgs().empty());
25162531
}
25172532

2518-
LogicalResult IterateOp::verify() {
2519-
if (getInitArgs().size() != getNumResults()) {
2520-
return emitOpError(
2521-
"mismatch in number of loop-carried values and defined values");
2522-
}
2523-
if (getCrdUsedLvls().max() > getSpaceDim())
2524-
return emitOpError("required out-of-bound coordinates");
2525-
2526-
return success();
2527-
}
2528-
25292533
LogicalResult IterateOp::verifyRegions() {
25302534
if (getIterator().getType() != getIterSpace().getType().getIteratorType())
25312535
return emitOpError("mismatch in iterator and iteration space type");
@@ -2665,6 +2669,54 @@ void CoIterateOp::print(OpAsmPrinter &p) {
26652669
}
26662670
}
26672671

2672+
ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
2673+
return cast<sparse_tensor::YieldOp>(
2674+
getRegion(regionIdx).getBlocks().front().getTerminator())
2675+
.getResults();
2676+
}
2677+
2678+
LogicalResult CoIterateOp::verifyRegions() {
2679+
for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2680+
if (getNumRegionIterArgs(r) != getNumResults())
2681+
return emitOpError(
2682+
"mismatch in number of basic block args and defined values");
2683+
2684+
auto initArgs = getInitArgs();
2685+
auto iterArgs = getRegionIterArgs(r);
2686+
auto yieldVals = getYieldedValues(r);
2687+
auto opResults = getResults();
2688+
if (!llvm::all_equal({initArgs.size(), iterArgs.size(), yieldVals.size(),
2689+
opResults.size()})) {
2690+
return emitOpError()
2691+
<< "number mismatch between iter args and results on " << r
2692+
<< "th region";
2693+
}
2694+
2695+
for (auto [i, init, iter, yield, ret] :
2696+
llvm::enumerate(initArgs, iterArgs, yieldVals, opResults)) {
2697+
if (init.getType() != ret.getType())
2698+
return emitOpError()
2699+
<< "types mismatch between " << i
2700+
<< "th iter operand and defined value on " << r << "th region";
2701+
if (iter.getType() != ret.getType())
2702+
return emitOpError() << "types mismatch between " << i
2703+
<< "th iter region arg and defined value on " << r
2704+
<< "th region";
2705+
if (yield.getType() != ret.getType())
2706+
return emitOpError()
2707+
<< "types mismatch between " << i
2708+
<< "th yield value and defined value on " << r << "th region";
2709+
}
2710+
}
2711+
2712+
auto cases = getRegionDefinedSpaces();
2713+
llvm::SmallSetVector<uint64_t, 8> set(cases.begin(), cases.end());
2714+
if (set.size() != getNumRegions())
2715+
return emitOpError("contains duplicated cases.");
2716+
2717+
return success();
2718+
}
2719+
26682720
//===----------------------------------------------------------------------===//
26692721
// Sparse Tensor Dialect Setups.
26702722
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1191,3 +1191,78 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
11911191
}
11921192
return %r1 : index
11931193
}
1194+
1195+
// -----
1196+
1197+
#COO = #sparse_tensor.encoding<{
1198+
map = (i, j) -> (
1199+
i : compressed(nonunique),
1200+
j : singleton(soa)
1201+
)
1202+
}>
1203+
1204+
1205+
func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
1206+
%sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
1207+
%init = arith.constant 0 : index
1208+
// expected-error @+1 {{'sparse_tensor.coiterate' op contains duplicated cases.}}
1209+
%ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
1210+
: (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
1211+
-> index
1212+
case %it1, _ {
1213+
sparse_tensor.yield %arg : index
1214+
}
1215+
case %it1, _ {
1216+
sparse_tensor.yield %arg : index
1217+
}
1218+
return %ret : index
1219+
}
1220+
1221+
1222+
// -----
1223+
1224+
#COO = #sparse_tensor.encoding<{
1225+
map = (i, j) -> (
1226+
i : compressed(nonunique),
1227+
j : singleton(soa)
1228+
)
1229+
}>
1230+
1231+
1232+
func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
1233+
%sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
1234+
%init = arith.constant 0 : index
1235+
// expected-error @+1 {{'sparse_tensor.coiterate' op types mismatch between 0th yield value and defined value on 0th region}}
1236+
%ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord) iter_args(%arg = %init)
1237+
: (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
1238+
-> index
1239+
case %it1, _ {
1240+
%i = arith.constant 1 : i32
1241+
sparse_tensor.yield %i : i32
1242+
}
1243+
return %ret : index
1244+
}
1245+
1246+
// -----
1247+
1248+
#COO = #sparse_tensor.encoding<{
1249+
map = (i, j) -> (
1250+
i : compressed(nonunique),
1251+
j : singleton(soa)
1252+
)
1253+
}>
1254+
1255+
1256+
func.func @sparse_coiteration(%sp1 : !sparse_tensor.iter_space<#COO, lvls = 0>,
1257+
%sp2 : !sparse_tensor.iter_space<#COO, lvls = 1>) -> index {
1258+
%init = arith.constant 0 : index
1259+
// expected-error @+1 {{'sparse_tensor.coiterate' op required out-of-bound coordinates}}
1260+
%ret = sparse_tensor.coiterate (%sp1, %sp2) at (%coord1, %coord2) iter_args(%arg = %init)
1261+
: (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>)
1262+
-> index
1263+
case %it1, _ {
1264+
%i = arith.constant 1 : i32
1265+
sparse_tensor.yield %i : i32
1266+
}
1267+
return %ret : index
1268+
}

0 commit comments

Comments
 (0)