Skip to content

Commit 076d3e2

Browse files
authored
[mlir][ods] Verify access to operands in inferReturnTypes (#112574)
The patch adds graceful handling of incorrectly constructed MLIR operation with less operands than expected.
1 parent 076aac5 commit 076d3e2

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

mlir/test/mlir-tblgen/op-result.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ def OpL1 : NS_Op<"op_with_all_types_constraint",
130130

131131
// CHECK-LABEL: LogicalResult OpL1::inferReturnTypes
132132
// CHECK-NOT: }
133+
// CHECK: if (operands.size() <= 0)
134+
// CHECK-NEXT: return ::mlir::failure();
133135
// CHECK: ::mlir::Type odsInferredType0 = operands[0].getType();
134136
// CHECK: inferredReturnTypes[0] = odsInferredType0;
135137

@@ -141,6 +143,9 @@ def OpL2 : NS_Op<"op_with_all_types_constraint",
141143

142144
// CHECK-LABEL: LogicalResult OpL2::inferReturnTypes
143145
// CHECK-NOT: }
146+
// CHECK: if (operands.size() <= 2)
147+
// CHECK-NEXT: return ::mlir::failure();
148+
// CHECK-NOT: if (operands.size() <= 0)
144149
// CHECK: ::mlir::Type odsInferredType0 = operands[2].getType();
145150
// CHECK: ::mlir::Type odsInferredType1 = operands[0].getType();
146151
// CHECK: inferredReturnTypes[0] = odsInferredType0;
@@ -166,6 +171,8 @@ def OpL4 : NS_Op<"two_inference_edges", [
166171
}
167172

168173
// CHECK-LABEL: LogicalResult OpL4::inferReturnTypes
174+
// CHECK: if (operands.size() <= 0)
175+
// CHECK-NEXT: return ::mlir::failure();
169176
// CHECK: odsInferredType0 = fromInput(operands[0].getType())
170177
// CHECK: odsInferredType1 = infer0(odsInferredType0)
171178
// CHECK: odsInferredType2 = infer1(odsInferredType1)

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3584,6 +3584,24 @@ void OpEmitter::genTypeInterfaceMethods() {
35843584
fctx.addSubst("_ctxt", "context");
35853585
body << " ::mlir::Builder odsBuilder(context);\n";
35863586

3587+
// Preprocessing stage to verify all accesses to operands are valid.
3588+
int maxAccessedIndex = -1;
3589+
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
3590+
const InferredResultType &infer = op.getInferredResultType(i);
3591+
if (!infer.isArg())
3592+
continue;
3593+
Operator::OperandOrAttribute arg =
3594+
op.getArgToOperandOrAttribute(infer.getIndex());
3595+
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
3596+
maxAccessedIndex =
3597+
std::max(maxAccessedIndex, arg.operandOrAttributeIndex());
3598+
}
3599+
}
3600+
if (maxAccessedIndex != -1) {
3601+
body << " if (operands.size() <= " << Twine(maxAccessedIndex) << ")\n";
3602+
body << " return ::mlir::failure();\n";
3603+
}
3604+
35873605
// Process the type inference graph in topological order, starting from types
35883606
// that are always fully-inferred: operands and results with constructible
35893607
// types. The type inference graph here will always be a DAG, so this gives
@@ -3600,7 +3618,8 @@ void OpEmitter::genTypeInterfaceMethods() {
36003618
if (infer.isArg()) {
36013619
// If this is an operand, just index into operand list to access the
36023620
// type.
3603-
auto arg = op.getArgToOperandOrAttribute(infer.getIndex());
3621+
Operator::OperandOrAttribute arg =
3622+
op.getArgToOperandOrAttribute(infer.getIndex());
36043623
if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) {
36053624
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
36063625
"].getType()")

0 commit comments

Comments
 (0)