Skip to content

Commit 105c992

Browse files
authored
[mlir] use irdl as matcher description in transform (#89779)
Introduce a new Transform dialect extension that uses IRDL op definitions as matcher descriptors. IRDL allows one to essentially define additional op constraits to be verified and, unlike PDL, does not assume rewriting will happen. Leverage IRDL verification capability to filter out ops that match an IRDL definition without actually registering the corresponding operation with the system.
1 parent 11bda17 commit 105c992

File tree

14 files changed

+302
-27
lines changed

14 files changed

+302
-27
lines changed

mlir/include/mlir/Dialect/IRDL/IRDLVerifiers.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ class DynamicTypeDefinition;
3030
namespace mlir {
3131
namespace irdl {
3232

33+
class AttributeOp;
3334
class Constraint;
35+
class OperationOp;
36+
class TypeOp;
3437

3538
/// Provides context to the verification of constraints.
3639
/// It contains the assignment of variables to attributes, and the assignment
@@ -246,6 +249,14 @@ struct RegionConstraint {
246249
std::optional<SmallVector<unsigned>> argumentConstraints;
247250
std::optional<size_t> blockCount;
248251
};
252+
253+
/// Generate an op verifier function from the given IRDL operation definition.
254+
llvm::unique_function<LogicalResult(Operation *) const> createVerifier(
255+
OperationOp operation,
256+
const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>
257+
&typeDefs,
258+
const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
259+
&attrDefs);
249260
} // namespace irdl
250261
} // namespace mlir
251262

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_subdirectory(DebugExtension)
22
add_subdirectory(Interfaces)
33
add_subdirectory(IR)
4+
add_subdirectory(IRDLExtension)
45
add_subdirectory(LoopExtension)
56
add_subdirectory(PDLExtension)
67
add_subdirectory(Transforms)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
set(LLVM_TARGET_DEFINITIONS IRDLExtensionOps.td)
2+
mlir_tablegen(IRDLExtensionOps.h.inc -gen-op-decls)
3+
mlir_tablegen(IRDLExtensionOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRTransformDialectIRDLExtensionOpsIncGen)
5+
6+
add_mlir_doc(IRDLExtensionOps IRDLExtensionOps Dialects/ -gen-op-doc)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- IRDLExtension.h - IRDL extension for Transform dialect ---*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
10+
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace transform {
16+
/// Registers the IRDL extension of the Transform dialect in the given registry.
17+
void registerIRDLExtension(DialectRegistry &dialectRegistry);
18+
} // namespace transform
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSION_H
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- IRDLExtensionOps.h - IRDL Transform dialect extension ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
10+
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
11+
12+
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
14+
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
15+
#include "mlir/IR/OpDefinition.h"
16+
17+
#define GET_OP_CLASSES
18+
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h.inc"
19+
20+
#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS_H
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- IRDLExtensionOps.td - Transform dialect extension ---*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
10+
#define MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
15+
include "mlir/IR/SymbolInterfaces.td"
16+
17+
def IRDLCollectMatchingOp : TransformDialectOp<"irdl.collect_matching",
18+
[DeclareOpInterfaceMethods<TransformOpInterface>,
19+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
20+
SymbolTable,
21+
NoTerminator]> {
22+
let summary =
23+
"Finds ops that match the IRDL definition without registering them.";
24+
25+
let arguments = (ins TransformHandleTypeInterface:$root);
26+
let regions = (region SizedRegion<1>:$body);
27+
let results = (outs TransformHandleTypeInterface:$matched);
28+
29+
let assemblyFormat =
30+
"`in` $root `:` functional-type(operands, results) attr-dict-with-keyword "
31+
"regions";
32+
33+
let hasVerifier = 1;
34+
}
35+
36+
#endif // MLIR_DIALECT_TRANSFORM_IRDLEXTENSION_IRDLEXTENSIONOPS

mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
10+
#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H
11+
912
namespace mlir {
1013
class DialectRegistry;
1114

@@ -14,3 +17,5 @@ namespace transform {
1417
void registerPDLExtension(DialectRegistry &dialectRegistry);
1518
} // namespace transform
1619
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSION_H

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
3636
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
3737
#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
38+
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
3839
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
3940
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
4041
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
@@ -77,6 +78,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
7778
sparse_tensor::registerTransformDialectExtension(registry);
7879
tensor::registerTransformDialectExtension(registry);
7980
transform::registerDebugExtension(registry);
81+
transform::registerIRDLExtension(registry);
8082
transform::registerLoopExtension(registry);
8183
transform::registerPDLExtension(registry);
8284
vector::registerTransformDialectExtension(registry);

mlir/lib/Dialect/IRDL/IRDLLoading.cpp

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -270,26 +270,30 @@ static LogicalResult irdlRegionVerifier(
270270
return success();
271271
}
272272

273-
/// Define and load an operation represented by a `irdl.operation`
274-
/// operation.
275-
static WalkResult loadOperation(
276-
OperationOp op, ExtensibleDialect *dialect,
277-
DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
278-
DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> &attrs) {
273+
llvm::unique_function<LogicalResult(Operation *) const>
274+
mlir::irdl::createVerifier(
275+
OperationOp op,
276+
const DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
277+
const DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
278+
&attrs) {
279279
// Resolve SSA values to verifier constraint slots
280280
SmallVector<Value> constrToValue;
281281
SmallVector<Value> regionToValue;
282282
for (Operation &op : op->getRegion(0).getOps()) {
283283
if (isa<VerifyConstraintInterface>(op)) {
284-
if (op.getNumResults() != 1)
285-
return op.emitError()
286-
<< "IRDL constraint operations must have exactly one result";
284+
if (op.getNumResults() != 1) {
285+
op.emitError()
286+
<< "IRDL constraint operations must have exactly one result";
287+
return nullptr;
288+
}
287289
constrToValue.push_back(op.getResult(0));
288290
}
289291
if (isa<VerifyRegionInterface>(op)) {
290-
if (op.getNumResults() != 1)
291-
return op.emitError()
292-
<< "IRDL constraint operations must have exactly one result";
292+
if (op.getNumResults() != 1) {
293+
op.emitError()
294+
<< "IRDL constraint operations must have exactly one result";
295+
return nullptr;
296+
}
293297
regionToValue.push_back(op.getResult(0));
294298
}
295299
}
@@ -302,7 +306,7 @@ static WalkResult loadOperation(
302306
std::unique_ptr<Constraint> verifier =
303307
op.getVerifier(constrToValue, types, attrs);
304308
if (!verifier)
305-
return WalkResult::interrupt();
309+
return nullptr;
306310
constraints.push_back(std::move(verifier));
307311
}
308312

@@ -358,7 +362,7 @@ static WalkResult loadOperation(
358362
}
359363

360364
// Gather which constraint slots correspond to attributes constraints
361-
DenseMap<StringAttr, size_t> attributesContraints;
365+
DenseMap<StringAttr, size_t> attributeConstraints;
362366
auto attributesOp = op.getOp<AttributesOp>();
363367
if (attributesOp.has_value()) {
364368
const Operation::operand_range values = attributesOp->getAttributeValues();
@@ -367,40 +371,53 @@ static WalkResult loadOperation(
367371
for (const auto &[name, value] : llvm::zip(names, values)) {
368372
for (auto [i, constr] : enumerate(constrToValue)) {
369373
if (constr == value) {
370-
attributesContraints[cast<StringAttr>(name)] = i;
374+
attributeConstraints[cast<StringAttr>(name)] = i;
371375
break;
372376
}
373377
}
374378
}
375379
}
376380

377-
// IRDL does not support defining custom parsers or printers.
378-
auto parser = [](OpAsmParser &parser, OperationState &result) {
379-
return failure();
380-
};
381-
auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
382-
printer.printGenericOp(op);
383-
};
384-
385-
auto verifier =
381+
return
386382
[constraints{std::move(constraints)},
387383
regionConstraints{std::move(regionConstraints)},
388384
operandConstraints{std::move(operandConstraints)},
389385
operandVariadicity{std::move(operandVariadicity)},
390386
resultConstraints{std::move(resultConstraints)},
391387
resultVariadicity{std::move(resultVariadicity)},
392-
attributesContraints{std::move(attributesContraints)}](Operation *op) {
388+
attributeConstraints{std::move(attributeConstraints)}](Operation *op) {
393389
ConstraintVerifier verifier(constraints);
394390
const LogicalResult opVerifierResult = irdlOpVerifier(
395391
op, verifier, operandConstraints, operandVariadicity,
396-
resultConstraints, resultVariadicity, attributesContraints);
392+
resultConstraints, resultVariadicity, attributeConstraints);
397393
const LogicalResult opRegionVerifierResult =
398394
irdlRegionVerifier(op, verifier, regionConstraints);
399395
return LogicalResult::success(opVerifierResult.succeeded() &&
400396
opRegionVerifierResult.succeeded());
401397
};
398+
}
399+
400+
/// Define and load an operation represented by a `irdl.operation`
401+
/// operation.
402+
static WalkResult loadOperation(
403+
OperationOp op, ExtensibleDialect *dialect,
404+
const DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> &types,
405+
const DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
406+
&attrs) {
407+
408+
// IRDL does not support defining custom parsers or printers.
409+
auto parser = [](OpAsmParser &parser, OperationState &result) {
410+
return failure();
411+
};
412+
auto printer = [](Operation *op, OpAsmPrinter &printer, StringRef) {
413+
printer.printGenericOp(op);
414+
};
415+
416+
auto verifier = createVerifier(op, types, attrs);
417+
if (!verifier)
418+
return WalkResult::interrupt();
402419

403-
// IRDL supports only checking number of blocks and argument contraints
420+
// IRDL supports only checking number of blocks and argument constraints
404421
// It is done in the main verifier to reuse `ConstraintVerifier` context
405422
auto regionVerifier = [](Operation *op) { return LogicalResult::success(); };
406423

mlir/lib/Dialect/Transform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
add_subdirectory(DebugExtension)
22
add_subdirectory(Interfaces)
33
add_subdirectory(IR)
4+
add_subdirectory(IRDLExtension)
45
add_subdirectory(LoopExtension)
56
add_subdirectory(PDLExtension)
67
add_subdirectory(Transforms)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
add_mlir_dialect_library(MLIRTransformDialectIRDLExtension
2+
IRDLExtension.cpp
3+
IRDLExtensionOps.cpp
4+
5+
DEPENDS
6+
MLIRTransformDialectIRDLExtensionOpsIncGen
7+
8+
LINK_LIBS PUBLIC
9+
MLIRIR
10+
MLIRTransformDialect
11+
MLIRIRDL
12+
)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===- IRDLExtension.cpp - IRDL extension for the Transform dialect -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
10+
#include "mlir/Dialect/IRDL/IR/IRDL.h"
11+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
12+
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h"
13+
#include "mlir/IR/DialectRegistry.h"
14+
15+
using namespace mlir;
16+
17+
namespace {
18+
class IRDLExtension
19+
: public transform::TransformDialectExtension<IRDLExtension> {
20+
public:
21+
void init() {
22+
registerTransformOps<
23+
#define GET_OP_LIST
24+
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc"
25+
>();
26+
27+
declareDependentDialect<irdl::IRDLDialect>();
28+
}
29+
};
30+
} // namespace
31+
32+
void mlir::transform::registerIRDLExtension(DialectRegistry &dialectRegistry) {
33+
dialectRegistry.addExtensions<IRDLExtension>();
34+
}

0 commit comments

Comments
 (0)