Skip to content

[mlir] [tblgen-to-irdl] Refactor tblgen-to-irdl script and support more types #105505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,10 @@ class AllOfType<list<Type> allowedTypeList, string summary = "",
class ConfinedType<Type type, list<Pred> predicates, string summary = "",
string cppType = type.cppType> : Type<
And<!listconcat([type.predicate], !foreach(pred, predicates, pred))>,
summary, cppType>;
summary, cppType> {
Type baseType = type;
list<Pred> predicateList = predicates;
}

// Integer types.

Expand Down
1 change: 0 additions & 1 deletion mlir/test/tblgen-to-irdl/CMathDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {

// CHECK: irdl.operation @identity {
// CHECK-NEXT: %0 = irdl.base "!cmath.complex"
// CHECK-NEXT: irdl.operands()
// CHECK-NEXT: irdl.results(%0)
// CHECK-NEXT: }
def CMath_IdentityOp : CMath_Op<"identity"> {
Expand Down
59 changes: 53 additions & 6 deletions mlir/test/tblgen-to-irdl/TestDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> {
// CHECK-LABEL: irdl.operation @and {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: irdl.operands(%[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }


Expand All @@ -41,9 +40,39 @@ def Test_AnyOp : Test_Op<"any"> {
// CHECK-LABEL: irdl.operation @any {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
// CHECK-NEXT: irdl.operands(%[[v0]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }

// Check confined types are converted correctly.
def Test_ConfinedOp : Test_Op<"confined"> {
let arguments = (ins ConfinedType<AnyType, [CPred<"::llvm::isa<::mlir::TensorType>($_self)">]>:$tensor,
ConfinedType<AnyType, [And<[CPred<"::llvm::isa<::mlir::VectorType>($_self)">
, CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>]>:$vector);
}
// CHECK-LABEL: irdl.operation @confined {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "(::llvm::isa<::mlir::TensorType>($_self))"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]])
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any
// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "(::llvm::isa<::mlir::VectorType>($_self))"
// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "(::llvm::cast<::mlir::VectorType>($_self).getRank() > 0)"
// CHECK-NEXT: %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]])
// CHECK-NEXT: %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]])
// CHECK-NEXT: irdl.operands(%[[v2]], %[[v7]])
// CHECK-NEXT: }

// Check generic integer types are converted correctly.
def Test_Integers : Test_Op<"integers"> {
let arguments = (ins AnyI8:$any_int,
AnyInteger:$any_integer);
}
// CHECK-LABEL: irdl.operation @integers {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i8
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.is si8
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.is ui8
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.base "!builtin.integer"
// CHECK-NEXT: irdl.operands(%[[v3]], %[[v4]])
// CHECK-NEXT: }

// Check that AnyTypeOf is converted correctly.
def Test_OrOp : Test_Op<"or"> {
Expand All @@ -53,11 +82,30 @@ def Test_OrOp : Test_Op<"or"> {
// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a"
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.operands(%[[v3]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }

// Check that various types are converted correctly.
def Test_TypesOp : Test_Op<"types"> {
let arguments = (ins I32:$a,
SI64:$b,
UI8:$c,
Index:$d,
F32:$e,
NoneType:$f,
Complex<F8E4M3FN>);
}
// CHECK-LABEL: irdl.operation @types {
// CHECK-NEXT: %{{.*}} = irdl.is i32
// CHECK-NEXT: %{{.*}} = irdl.is si64
// CHECK-NEXT: %{{.*}} = irdl.is ui8
// CHECK-NEXT: %{{.*}} = irdl.is index
// CHECK-NEXT: %{{.*}} = irdl.is f32
// CHECK-NEXT: %{{.*}} = irdl.is none
// CHECK-NEXT: %{{.*}} = irdl.is complex<f8E4M3FN>
// CHECK-NEXT: irdl.operands({{.*}})
// CHECK-NEXT: }

// Check that variadics and optionals are converted correctly.
def Test_VariadicityOp : Test_Op<"variadicity"> {
Expand All @@ -70,5 +118,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> {
// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b"
// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c"
// CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]])
// CHECK-NEXT: irdl.results()
// CHECK-NEXT: }
179 changes: 170 additions & 9 deletions mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,131 @@ llvm::cl::opt<std::string>
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
llvm::cl::cat(dialectGenCat), llvm::cl::Required);

Value createPredicate(OpBuilder &builder, tblgen::Pred pred) {
MLIRContext *ctx = builder.getContext();

if (pred.isCombined()) {
auto combiner = pred.getDef().getValueAsDef("kind")->getName();
if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") {
std::vector<Value> constraints;
for (auto *child : pred.getDef().getValueAsListOfDefs("children")) {
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
}
if (combiner == "PredCombinerAnd") {
auto op =
builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}
auto op =
builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}
}

std::string condition = pred.getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
return op;
}

Value typeToConstraint(OpBuilder &builder, Type type) {
MLIRContext *ctx = builder.getContext();
auto op =
builder.create<irdl::IsOp>(UnknownLoc::get(ctx), TypeAttr::get(type));
return op.getOutput();
}

std::optional<Type> recordToType(MLIRContext *ctx, const Record &predRec) {

if (predRec.isSubClassOf("I")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Signless);
}

if (predRec.isSubClassOf("SI")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Signed);
}

if (predRec.isSubClassOf("UI")) {
auto width = predRec.getValueAsInt("bitwidth");
return IntegerType::get(ctx, width, IntegerType::Unsigned);
}

// Index type
if (predRec.getName() == "Index") {
return IndexType::get(ctx);
}

// Float types
if (predRec.isSubClassOf("F")) {
auto width = predRec.getValueAsInt("bitwidth");
switch (width) {
case 16:
return FloatType::getF16(ctx);
case 32:
return FloatType::getF32(ctx);
case 64:
return FloatType::getF64(ctx);
case 80:
return FloatType::getF80(ctx);
case 128:
return FloatType::getF128(ctx);
}
}

if (predRec.getName() == "NoneType") {
return NoneType::get(ctx);
}

if (predRec.getName() == "BF16") {
return FloatType::getBF16(ctx);
}

if (predRec.getName() == "TF32") {
return FloatType::getTF32(ctx);
}

if (predRec.getName() == "F8E4M3FN") {
return FloatType::getFloat8E4M3FN(ctx);
}

if (predRec.getName() == "F8E5M2") {
return FloatType::getFloat8E5M2(ctx);
}

if (predRec.getName() == "F8E4M3") {
return FloatType::getFloat8E4M3(ctx);
}

if (predRec.getName() == "F8E4M3FNUZ") {
return FloatType::getFloat8E4M3FNUZ(ctx);
}

if (predRec.getName() == "F8E4M3B11FNUZ") {
return FloatType::getFloat8E4M3B11FNUZ(ctx);
}

if (predRec.getName() == "F8E5M2FNUZ") {
return FloatType::getFloat8E5M2FNUZ(ctx);
}

if (predRec.getName() == "F8E3M4") {
return FloatType::getFloat8E3M4(ctx);
}

if (predRec.isSubClassOf("Complex")) {
const Record *elementRec = predRec.getValueAsDef("elementType");
auto elementType = recordToType(ctx, *elementRec);
if (elementType.has_value()) {
return ComplexType::get(elementType.value());
}
}

return std::nullopt;
}

Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
MLIRContext *ctx = builder.getContext();
const Record &predRec = constraint.getDef();
Expand Down Expand Up @@ -78,11 +203,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
return op.getOutput();
}

std::string condition = constraint.getPredicate().getCondition();
// Build a CPredOp to match the C constraint built.
irdl::CPredOp op = builder.create<irdl::CPredOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, condition));
return op;
// Integer types
if (predRec.getName() == "AnyInteger") {
auto op = builder.create<irdl::BaseOp>(
UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer"));
return op.getOutput();
}

if (predRec.isSubClassOf("AnyI")) {
auto width = predRec.getValueAsInt("bitwidth");
std::vector<Value> types = {
typeToConstraint(builder,
IntegerType::get(ctx, width, IntegerType::Signless)),
typeToConstraint(builder,
IntegerType::get(ctx, width, IntegerType::Signed)),
typeToConstraint(builder,
IntegerType::get(ctx, width, IntegerType::Unsigned))};
auto op = builder.create<irdl::AnyOfOp>(UnknownLoc::get(ctx), types);
return op.getOutput();
}

auto type = recordToType(ctx, predRec);

if (type.has_value()) {
return typeToConstraint(builder, type.value());
}

// Confined type
if (predRec.isSubClassOf("ConfinedType")) {
std::vector<Value> constraints;
constraints.push_back(createConstraint(
builder, tblgen::Constraint(predRec.getValueAsDef("baseType"))));
for (Record *child : predRec.getValueAsListOfDefs("predicateList")) {
constraints.push_back(createPredicate(builder, tblgen::Pred(child)));
}
auto op = builder.create<irdl::AllOfOp>(UnknownLoc::get(ctx), constraints);
return op.getOutput();
}

return createPredicate(builder, constraint.getPredicate());
}

/// Returns the name of the operation without the dialect prefix.
Expand Down Expand Up @@ -131,10 +290,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());

// Create the operands and results operations.
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
operandVariadicity);
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);
if (!operands.empty())
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
operandVariadicity);
if (!results.empty())
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
resultVariadicity);

return op;
}
Expand Down
Loading