Skip to content

Commit 28fe1a4

Browse files
Tai78641eric-k256
authored andcommitted
[mlir] Add trait SameOperandsAndResultRank
This adds a native op trait SameOperandsAndResultRank and associated verifier that checks that an operator's operands and result types have same ranks if their ranks are known. Signed-off-by: Tai Ly <[email protected]> Change-Id: I2d536f77be10f3710d0c8d84c907ff492a984fda Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D156369
1 parent 1eab92b commit 28fe1a4

File tree

5 files changed

+93
-0
lines changed

5 files changed

+93
-0
lines changed

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ LogicalResult verifySameOperandsAndResultShape(Operation *op);
341341
LogicalResult verifySameOperandsElementType(Operation *op);
342342
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
343343
LogicalResult verifySameOperandsAndResultType(Operation *op);
344+
LogicalResult verifySameOperandsAndResultRank(Operation *op);
344345
LogicalResult verifyResultsAreBoolLike(Operation *op);
345346
LogicalResult verifyResultsAreFloatLike(Operation *op);
346347
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
@@ -1114,6 +1115,17 @@ class SameOperandsAndResultType
11141115
}
11151116
};
11161117

1118+
/// This class verifies that op has same ranks for all
1119+
/// operands and results types, if known.
1120+
template <typename ConcreteType>
1121+
class SameOperandsAndResultRank
1122+
: public TraitBase<ConcreteType, SameOperandsAndResultRank> {
1123+
public:
1124+
static LogicalResult verifyTrait(Operation *op) {
1125+
return impl::verifySameOperandsAndResultRank(op);
1126+
}
1127+
};
1128+
11171129
/// This class verifies that any results of the specified op have a boolean
11181130
/// type, a vector thereof, or a tensor thereof.
11191131
template <typename ConcreteType>

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,4 +369,7 @@ def ReifyRankedShapedTypeOpInterface :
369369
// TODO: Change from hard coded to utilizing type inference trait.
370370
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
371371

372+
// Op has the same ranks for all operands and results types, if known.
373+
def SameOperandsAndResultRank : NativeOpTrait<"SameOperandsAndResultRank">;
374+
372375
#endif // MLIR_INFERTYPEOPINTERFACE

mlir/lib/IR/Operation.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,51 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
10821082
return success();
10831083
}
10841084

1085+
LogicalResult OpTrait::impl::verifySameOperandsAndResultRank(Operation *op) {
1086+
if (failed(verifyAtLeastNOperands(op, 1)))
1087+
return failure();
1088+
1089+
// delegate function that returns true if type is a shaped type with known
1090+
// rank
1091+
auto hasRank = [](const Type type) {
1092+
if (auto shaped_type = dyn_cast<ShapedType>(type))
1093+
return shaped_type.hasRank();
1094+
1095+
return false;
1096+
};
1097+
1098+
auto rankedOperandTypes =
1099+
llvm::make_filter_range(op->getOperandTypes(), hasRank);
1100+
auto rankedResultTypes =
1101+
llvm::make_filter_range(op->getResultTypes(), hasRank);
1102+
1103+
// If all operands and results are unranked, then no further verification.
1104+
if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1105+
return success();
1106+
1107+
// delegate function that returns rank of shaped type with known rank
1108+
auto getRank = [](const Type type) {
1109+
return type.cast<ShapedType>().getRank();
1110+
};
1111+
1112+
auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1113+
: getRank(*rankedResultTypes.begin());
1114+
1115+
for (const auto type : rankedOperandTypes) {
1116+
if (rank != getRank(type)) {
1117+
return op->emitOpError("operands don't have matching ranks");
1118+
}
1119+
}
1120+
1121+
for (const auto type : rankedResultTypes) {
1122+
if (rank != getRank(type)) {
1123+
return op->emitOpError("result type has different rank than operands");
1124+
}
1125+
}
1126+
1127+
return success();
1128+
}
1129+
10851130
LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
10861131
Block *block = op->getBlock();
10871132
// Verify that the operation is at the end of the respective parent block.

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,12 @@ def OperandZeroAndResultHaveSameRank :
692692
let results = (outs AnyShaped:$res);
693693
}
694694

695+
def OperandsAndResultHaveSameRank :
696+
TEST_Op<"operands_and_result_have_same_rank", [SameOperandsAndResultRank]> {
697+
let arguments = (ins AnyShaped:$x, AnyShaped:$y);
698+
let results = (outs AnyShaped:$res);
699+
}
700+
695701
def OperandZeroAndResultHaveSameShape :
696702
TEST_Op<"operand0_and_result_have_same_shape",
697703
[AllShapesMatch<["x", "res"]>]> {

mlir/test/mlir-tblgen/types.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,33 @@ func.func @same_rank_failure(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xf32>) {
377377

378378
// -----
379379

380+
// CHECK-LABEL: same_rank_if_known_success
381+
func.func @same_rank_if_known_success(%t1xi : tensor<1xi32>, %t2xf : tensor<2xf32>, %m3xi : memref<3xi32>, %t1x2xf : tensor<1x2xf32>, %tuxi : tensor<*xi32>) {
382+
%0 = "test.operands_and_result_have_same_rank"(%t1xi, %t2xf) : (tensor<1xi32>, tensor<2xf32>) -> (tensor<3xf64>)
383+
%1 = "test.operands_and_result_have_same_rank"(%t1xi, %m3xi) : (tensor<1xi32>, memref<3xi32>) -> (tensor<3xi64>)
384+
%3 = "test.operands_and_result_have_same_rank"(%tuxi, %t2xf) : (tensor<*xi32>, tensor<2xf32>) -> (tensor<2xf32>)
385+
%4 = "test.operands_and_result_have_same_rank"(%t1x2xf, %tuxi) : (tensor<1x2xf32>, tensor<*xi32>) -> (tensor<1x2xf64>)
386+
return
387+
}
388+
389+
// -----
390+
391+
func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
392+
// expected-error@+1 {{operands don't have matching ranks}}
393+
%0 = "test.operands_and_result_have_same_rank"(%arg0, %arg1) : (tensor<1xi32>, tensor<1x2xf32>) -> (tensor<*xf32>)
394+
return
395+
}
396+
397+
// -----
398+
399+
func.func @same_rank_if_known_failure(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) {
400+
// expected-error@+1 {{result type has different rank than operands}}
401+
%0 = "test.operands_and_result_have_same_rank"(%arg1, %arg1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2x3xf32>)
402+
return
403+
}
404+
405+
// -----
406+
380407
// CHECK-LABEL: same_shape_success
381408
func.func @same_shape_success(%t2x3: tensor<2x3xi32>, %m2x3: memref<2x3xf32>, %v2x3 : vector<2x3xi32>, %t4x5 : tensor<4x5xi32>) {
382409
"test.operand0_and_result_have_same_shape"(%t2x3, %t4x5) : (tensor<2x3xi32>, tensor<4x5xi32>) -> (tensor<2x3xf32>)

0 commit comments

Comments
 (0)