Skip to content

Commit 3ffd13b

Browse files
committed
[mlir][irdl] Add IRDL verification constraint classes
This patch adds the necessary constraint classes that are be used by IRDL to define Operation, Type, and Attribute verifiers. A constraint is a class inheriting the `irdl::Constraint` class, which may call other constraints that are indexed by `unsigned`. A constraint represent an invariant over an Attribute. The `ConstraintVerifier` class group these constraints together, and make sure that a constraint can only identify a single attribute. So, once a constraint is used to check the satisfiability of an `Attribute`, the `Attribute` will be memorized for this constraint. This ensure that in IRDL, a single `!irdl.attribute` value only correspond to a single `Attribute`. Depends on D144693 Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D145733
1 parent 7fbf72a commit 3ffd13b

File tree

3 files changed

+362
-0
lines changed

3 files changed

+362
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
//===- IRDLVerifiers.h - IRDL verifiers --------------------------- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Verifiers for objects declared by IRDL.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
14+
#define MLIR_DIALECT_IRDL_IRDLVERIFIERS_H
15+
16+
#include "mlir/IR/Attributes.h"
17+
#include "mlir/Support/LLVM.h"
18+
#include "llvm/ADT/ArrayRef.h"
19+
#include <optional>
20+
21+
namespace mlir {
22+
struct LogicalResult;
23+
class InFlightDiagnostic;
24+
class DynamicAttrDefinition;
25+
class DynamicTypeDefinition;
26+
} // namespace mlir
27+
28+
namespace mlir {
29+
namespace irdl {
30+
31+
class Constraint;
32+
33+
/// Provides context to the verification of constraints.
34+
/// It contains the assignment of variables to attributes, and the assignment
35+
/// of variables to constraints.
36+
class ConstraintVerifier {
37+
public:
38+
ConstraintVerifier(ArrayRef<std::unique_ptr<Constraint>> constraints);
39+
40+
/// Check that a constraint is satisfied by an attribute.
41+
///
42+
/// Constraints may call other constraint verifiers. If that is the case,
43+
/// the constraint verifier will check if the variable is already assigned,
44+
/// and if so, check that the attribute is the same as the one assigned.
45+
/// If the variable is not assigned, the constraint verifier will
46+
/// assign the attribute to the variable, and check that the constraint
47+
/// is satisfied.
48+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
49+
Attribute attr, unsigned variable);
50+
51+
private:
52+
/// The constraints that can be used for verification.
53+
ArrayRef<std::unique_ptr<Constraint>> constraints;
54+
55+
/// The assignment of variables to attributes. Variables that are not assigned
56+
/// are represented by nullopt. Null attributes needs to be supported here as
57+
/// some attributes or types might use the null attribute to represent
58+
/// optional parameters.
59+
SmallVector<std::optional<Attribute>> assigned;
60+
};
61+
62+
/// Once turned into IRDL verifiers, all constraints are
63+
/// attribute constraints. Type constraints are represented
64+
/// as `TypeAttr` attribute constraints to simplify verification.
65+
/// Verification that a type constraint must yield a
66+
/// `TypeAttr` attribute happens before conversion, at the MLIR level.
67+
class Constraint {
68+
public:
69+
virtual ~Constraint() = default;
70+
71+
/// Check that an attribute is satisfying the constraint.
72+
///
73+
/// Constraints may call other constraint verifiers. If that is the case,
74+
/// the constraint verifier will check if the variable is already assigned,
75+
/// and if so, check that the attribute is the same as the one assigned.
76+
/// If the variable is not assigned, the constraint verifier will
77+
/// assign the attribute to the variable, and check that the constraint
78+
/// is satisfied.
79+
virtual LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
80+
Attribute attr,
81+
ConstraintVerifier &context) const = 0;
82+
};
83+
84+
/// A constraint that checks that an attribute is equal to a given attribute.
85+
class IsConstraint : public Constraint {
86+
public:
87+
IsConstraint(Attribute expectedAttribute)
88+
: expectedAttribute(expectedAttribute) {}
89+
90+
virtual ~IsConstraint() = default;
91+
92+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
93+
Attribute attr,
94+
ConstraintVerifier &context) const override;
95+
96+
private:
97+
Attribute expectedAttribute;
98+
};
99+
100+
/// A constraint that checks that an attribute is of a
101+
/// specific dynamic attribute definition, and that all of its parameters
102+
/// satisfy the given constraints.
103+
class DynParametricAttrConstraint : public Constraint {
104+
public:
105+
DynParametricAttrConstraint(DynamicAttrDefinition *attrDef,
106+
SmallVector<unsigned> constraints)
107+
: attrDef(attrDef), constraints(std::move(constraints)) {}
108+
109+
virtual ~DynParametricAttrConstraint() = default;
110+
111+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
112+
Attribute attr,
113+
ConstraintVerifier &context) const override;
114+
115+
private:
116+
DynamicAttrDefinition *attrDef;
117+
SmallVector<unsigned> constraints;
118+
};
119+
120+
/// A constraint that checks that a type is of a specific dynamic type
121+
/// definition, and that all of its parameters satisfy the given constraints.
122+
class DynParametricTypeConstraint : public Constraint {
123+
public:
124+
DynParametricTypeConstraint(DynamicTypeDefinition *typeDef,
125+
SmallVector<unsigned> constraints)
126+
: typeDef(typeDef), constraints(std::move(constraints)) {}
127+
128+
virtual ~DynParametricTypeConstraint() = default;
129+
130+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
131+
Attribute attr,
132+
ConstraintVerifier &context) const override;
133+
134+
private:
135+
DynamicTypeDefinition *typeDef;
136+
SmallVector<unsigned> constraints;
137+
};
138+
139+
/// A constraint checking that one of the given constraints is satisfied.
140+
class AnyOfConstraint : public Constraint {
141+
public:
142+
AnyOfConstraint(SmallVector<unsigned> constraints)
143+
: constraints(std::move(constraints)) {}
144+
145+
virtual ~AnyOfConstraint() = default;
146+
147+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
148+
Attribute attr,
149+
ConstraintVerifier &context) const override;
150+
151+
private:
152+
SmallVector<unsigned> constraints;
153+
};
154+
155+
/// A constraint checking that all of the given constraints are satisfied.
156+
class AllOfConstraint : public Constraint {
157+
public:
158+
AllOfConstraint(SmallVector<unsigned> constraints)
159+
: constraints(std::move(constraints)) {}
160+
161+
virtual ~AllOfConstraint() = default;
162+
163+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
164+
Attribute attr,
165+
ConstraintVerifier &context) const override;
166+
167+
private:
168+
SmallVector<unsigned> constraints;
169+
};
170+
171+
/// A constraint that is always satisfied.
172+
class AnyAttributeConstraint : public Constraint {
173+
public:
174+
virtual ~AnyAttributeConstraint() = default;
175+
176+
LogicalResult verify(function_ref<InFlightDiagnostic()> emitError,
177+
Attribute attr,
178+
ConstraintVerifier &context) const override;
179+
};
180+
181+
} // namespace irdl
182+
} // namespace mlir
183+
184+
#endif // MLIR_DIALECT_IRDL_IRDLVERIFIERS_H

mlir/lib/Dialect/IRDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_mlir_dialect_library(MLIRIRDL
22
IR/IRDL.cpp
33
IRDLLoading.cpp
4+
IRDLVerifiers.cpp
45

56
DEPENDS
67
MLIRIRDLIncGen
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
//===- IRDLVerifiers.cpp - IRDL verifiers ------------------------- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Verifiers for objects declared by IRDL.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/IRDL/IRDLVerifiers.h"
14+
#include "mlir/IR/Diagnostics.h"
15+
#include "mlir/IR/ExtensibleDialect.h"
16+
#include "mlir/Support/LogicalResult.h"
17+
18+
using namespace mlir;
19+
using namespace mlir::irdl;
20+
21+
ConstraintVerifier::ConstraintVerifier(
22+
ArrayRef<std::unique_ptr<Constraint>> constraints)
23+
: constraints(constraints), assigned() {
24+
assigned.resize(this->constraints.size());
25+
}
26+
27+
LogicalResult
28+
ConstraintVerifier::verify(function_ref<InFlightDiagnostic()> emitError,
29+
Attribute attr, unsigned variable) {
30+
31+
assert(variable < constraints.size() && "invalid constraint variable");
32+
33+
// If the variable is already assigned, check that the attribute is the same.
34+
if (assigned[variable].has_value()) {
35+
if (attr == assigned[variable].value()) {
36+
return success();
37+
} else {
38+
if (emitError)
39+
return emitError() << "expected '" << assigned[variable].value()
40+
<< "' but got '" << attr << "'";
41+
return failure();
42+
}
43+
}
44+
45+
// Otherwise, check the constraint and assign the attribute to the variable.
46+
LogicalResult result = constraints[variable]->verify(emitError, attr, *this);
47+
if (succeeded(result))
48+
assigned[variable] = attr;
49+
50+
return result;
51+
}
52+
53+
LogicalResult IsConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
54+
Attribute attr,
55+
ConstraintVerifier &context) const {
56+
if (attr == expectedAttribute)
57+
return success();
58+
59+
if (emitError)
60+
return emitError() << "expected '" << expectedAttribute << "' but got '"
61+
<< attr << "'";
62+
return failure();
63+
}
64+
65+
LogicalResult DynParametricAttrConstraint::verify(
66+
function_ref<InFlightDiagnostic()> emitError, Attribute attr,
67+
ConstraintVerifier &context) const {
68+
69+
// Check that the base is the expected one.
70+
auto dynAttr = attr.dyn_cast<DynamicAttr>();
71+
if (!dynAttr || dynAttr.getAttrDef() != attrDef) {
72+
if (emitError) {
73+
StringRef dialectName = attrDef->getDialect()->getNamespace();
74+
StringRef attrName = attrDef->getName();
75+
return emitError() << "expected base attribute '" << attrName << '.'
76+
<< dialectName << "' but got '" << attr << "'";
77+
}
78+
return failure();
79+
}
80+
81+
// Check that the parameters satisfy the constraints.
82+
ArrayRef<Attribute> params = dynAttr.getParams();
83+
if (params.size() != constraints.size()) {
84+
if (emitError) {
85+
StringRef dialectName = attrDef->getDialect()->getNamespace();
86+
StringRef attrName = attrDef->getName();
87+
emitError() << "attribute '" << dialectName << "." << attrName
88+
<< "' expects " << params.size() << " parameters but got "
89+
<< constraints.size();
90+
}
91+
return failure();
92+
}
93+
94+
for (size_t i = 0, s = params.size(); i < s; i++)
95+
if (failed(context.verify(emitError, params[i], constraints[i])))
96+
return failure();
97+
98+
return success();
99+
}
100+
101+
LogicalResult DynParametricTypeConstraint::verify(
102+
function_ref<InFlightDiagnostic()> emitError, Attribute attr,
103+
ConstraintVerifier &context) const {
104+
// Check that the base is a TypeAttr.
105+
auto typeAttr = attr.dyn_cast<TypeAttr>();
106+
if (!typeAttr) {
107+
if (emitError)
108+
return emitError() << "expected type, got attribute '" << attr;
109+
return failure();
110+
}
111+
112+
// Check that the type base is the expected one.
113+
auto dynType = typeAttr.getValue().dyn_cast<DynamicType>();
114+
if (!dynType || dynType.getTypeDef() != typeDef) {
115+
if (emitError) {
116+
StringRef dialectName = typeDef->getDialect()->getNamespace();
117+
StringRef attrName = typeDef->getName();
118+
return emitError() << "expected base type '" << dialectName << '.'
119+
<< attrName << "' but got '" << attr << "'";
120+
}
121+
return failure();
122+
}
123+
124+
// Check that the parameters satisfy the constraints.
125+
ArrayRef<Attribute> params = dynType.getParams();
126+
if (params.size() != constraints.size()) {
127+
if (emitError) {
128+
StringRef dialectName = typeDef->getDialect()->getNamespace();
129+
StringRef attrName = typeDef->getName();
130+
emitError() << "attribute '" << dialectName << "." << attrName
131+
<< "' expects " << params.size() << " parameters but got "
132+
<< constraints.size();
133+
}
134+
return failure();
135+
}
136+
137+
for (size_t i = 0, s = params.size(); i < s; i++)
138+
if (failed(context.verify(emitError, params[i], constraints[i])))
139+
return failure();
140+
141+
return success();
142+
}
143+
144+
LogicalResult
145+
AnyOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
146+
Attribute attr, ConstraintVerifier &context) const {
147+
for (unsigned constr : constraints) {
148+
// We do not pass the `emitError` here, since we want to emit an error
149+
// only if none of the constraints are satisfied.
150+
if (succeeded(context.verify({}, attr, constr))) {
151+
return success();
152+
}
153+
}
154+
155+
if (emitError)
156+
return emitError() << "'" << attr << "' does not satisfy the constraint";
157+
return failure();
158+
}
159+
160+
LogicalResult
161+
AllOfConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
162+
Attribute attr, ConstraintVerifier &context) const {
163+
for (unsigned constr : constraints) {
164+
if (failed(context.verify(emitError, attr, constr))) {
165+
return failure();
166+
}
167+
}
168+
169+
return success();
170+
}
171+
172+
LogicalResult
173+
AnyAttributeConstraint::verify(function_ref<InFlightDiagnostic()> emitError,
174+
Attribute attr,
175+
ConstraintVerifier &context) const {
176+
return success();
177+
}

0 commit comments

Comments
 (0)