Skip to content

[mlir] use irdl as matcher description in transform #89779

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ class DynamicTypeDefinition;
namespace mlir {
namespace irdl {

class AttributeOp;
class Constraint;
class OperationOp;
class TypeOp;

/// Provides context to the verification of constraints.
/// It contains the assignment of variables to attributes, and the assignment
Expand Down Expand Up @@ -246,6 +249,14 @@ struct RegionConstraint {
std::optional<SmallVector<unsigned>> argumentConstraints;
std::optional<size_t> blockCount;
};

/// Generate an op verifier function from the given IRDL operation definition.
llvm::unique_function<LogicalResult(Operation *) const> createVerifier(
OperationOp operation,
const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>
&typeDefs,
const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
&attrDefs);
} // namespace irdl
} // namespace mlir

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_subdirectory(DebugExtension)
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS IRDLExtensionOps.td)
mlir_tablegen(IRDLExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(IRDLExtensionOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTransformDialectIRDLExtensionOpsIncGen)

add_mlir_doc(IRDLExtensionOps IRDLExtensionOps Dialects/ -gen-op-doc)
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- IRDLExtension.h - IRDL extension for Transform dialect ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H

namespace mlir {
class DialectRegistry;

namespace transform {
/// Registers the IRDL extension of the Transform dialect in the given registry.
void registerIRDLExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- IRDLExtensionOps.h - IRDL Transform dialect extension ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h.inc"

#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===- IRDLExtensionOps.td - Transform dialect extension ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"

def IRDLCollectMatchingOp : TransformDialectOp<"irdl.collect_matching",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
SymbolTable,
NoTerminator]> {
let summary =
"Finds ops that match the IRDL definition without registering them.";

let arguments = (ins TransformHandleTypeInterface:$root);
let regions = (region SizedRegion<1>:$body);
let results = (outs TransformHandleTypeInterface:$matched);

let assemblyFormat =
"`in` $root `:` functional-type(operands, results) attr-dict-with-keyword "
"regions";

let hasVerifier = 1;
}

#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H

namespace mlir {
class DialectRegistry;

Expand All @@ -14,3 +17,5 @@ namespace transform {
void registerPDLExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
Expand Down Expand Up @@ -77,6 +78,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
sparse_tensor::registerTransformDialectExtension(registry);
tensor::registerTransformDialectExtension(registry);
transform::registerDebugExtension(registry);
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
vector::registerTransformDialectExtension(registry);
Expand Down
71 changes: 44 additions & 27 deletions mlir/lib/Dialect/IRDL/IRDLLoading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,26 +270,30 @@ static LogicalResult irdlRegionVerifier(
return success();
}

/// Define and load an operation represented by a `irdl.operation`
/// operation.
static WalkResult loadOperation(
OperationOp op, ExtensibleDialect *dialect,
DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
llvm::unique_function<LogicalResult(Operation *) const>
mlir::irdl::createVerifier(
OperationOp op,
const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
&attrs) {
// Resolve SSA values to verifier constraint slots
SmallVector<Value> constrToValue;
SmallVector<Value> regionToValue;
for (Operation &op : op->getRegion(0).getOps()) {
if (isa<VerifyConstraintInterface>(op)) {
if (op.getNumResults() != 1)
return op.emitError()
<< "IRDL constraint operations must have exactly one result";
if (op.getNumResults() != 1) {
op.emitError()
<< "IRDL constraint operations must have exactly one result";
return nullptr;
}
constrToValue.push_back(op.getResult(0));
}
if (isa<VerifyRegionInterface>(op)) {
if (op.getNumResults() != 1)
return op.emitError()
<< "IRDL constraint operations must have exactly one result";
if (op.getNumResults() != 1) {
op.emitError()
<< "IRDL constraint operations must have exactly one result";
return nullptr;
}
regionToValue.push_back(op.getResult(0));
}
}
Expand All @@ -302,7 +306,7 @@ static WalkResult loadOperation(
std::unique_ptr<Constraint> verifier =
op.getVerifier(constrToValue, types, attrs);
if (!verifier)
return WalkResult::interrupt();
return nullptr;
constraints.push_back(std::move(verifier));
}

Expand Down Expand Up @@ -358,7 +362,7 @@ static WalkResult loadOperation(
}

// Gather which constraint slots correspond to attributes constraints
DenseMap<StringAttr, size_t> attributesContraints;
DenseMap<StringAttr, size_t> attributeConstraints;
auto attributesOp = op.getOp<AttributesOp>();
if (attributesOp.has_value()) {
const Operation::operand_range values = attributesOp->getAttributeValues();
Expand All @@ -367,40 +371,53 @@ static WalkResult loadOperation(
for (const auto &[name, value] : llvm::zip(names, values)) {
for (auto [i, constr] : enumerate(constrToValue)) {
if (constr == value) {
attributesContraints[cast<StringAttr>(name)] = i;
attributeConstraints[cast<StringAttr>(name)] = i;
break;
}
}
}
}

// IRDL does not support defining custom parsers or printers.
auto parser = [](OpAsmParser &parser, OperationState &result) {
return failure();
};
auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
printer.printGenericOp(op);
};

auto verifier =
return
[constraints{std::move(constraints)},
regionConstraints{std::move(regionConstraints)},
operandConstraints{std::move(operandConstraints)},
operandVariadicity{std::move(operandVariadicity)},
resultConstraints{std::move(resultConstraints)},
resultVariadicity{std::move(resultVariadicity)},
attributesContraints{std::move(attributesContraints)}](Operation *op) {
attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
ConstraintVerifier verifier(constraints);
const LogicalResult opVerifierResult = irdlOpVerifier(
op, verifier, operandConstraints, operandVariadicity,
resultConstraints, resultVariadicity, attributesContraints);
resultConstraints, resultVariadicity, attributeConstraints);
const LogicalResult opRegionVerifierResult =
irdlRegionVerifier(op, verifier, regionConstraints);
return LogicalResult::success(opVerifierResult.succeeded() &&
opRegionVerifierResult.succeeded());
};
}

/// Define and load an operation represented by a `irdl.operation`
/// operation.
static WalkResult loadOperation(
OperationOp op, ExtensibleDialect *dialect,
const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
&attrs) {

// IRDL does not support defining custom parsers or printers.
auto parser = [](OpAsmParser &parser, OperationState &result) {
return failure();
};
auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
printer.printGenericOp(op);
};

auto verifier = createVerifier(op, types, attrs);
if (!verifier)
return WalkResult::interrupt();

// IRDL supports only checking number of blocks and argument contraints
// IRDL supports only checking number of blocks and argument constraints
// It is done in the main verifier to reuse `ConstraintVerifier` context
auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_subdirectory(DebugExtension)
add_subdirectory(Interfaces)
add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
Expand Down
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Transform/IRDLExtension/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_dialect_library(MLIRTransformDialectIRDLExtension
IRDLExtension.cpp
IRDLExtensionOps.cpp

DEPENDS
MLIRTransformDialectIRDLExtensionOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
MLIRIRDL
)
34 changes: 34 additions & 0 deletions mlir/lib/Dialect/Transform/IRDLExtension/IRDLExtension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//===- IRDLExtension.cpp - IRDL extension for the Transform dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/IRDL/IR/IRDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;

namespace {
class IRDLExtension
: public transform::TransformDialectExtension<IRDLExtension> {
public:
void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
>();

declareDependentDialect<irdl::IRDLDialect>();
}
};
} // namespace

void mlir::transform::registerIRDLExtension(DialectRegistry &dialectRegistry) {
dialectRegistry.addExtensions<IRDLExtension>();
}
Loading