Skip to content

Commit 63c9d2b

Browse files
committed
[mlir] Add transform.foreach_match
Add a new transform op combinator that implements an "if-then-else" style of mechanism for applying transformations. Its main purpose is to serve as a higher-level driver when applying multiple transform scripts to potentially overlapping pieces of the payload IR. This is similar to how the various rewrite drivers operate in C++, but at a higher level and with more declarative expressions. This is not intended to replace existing pattern-based rewrites, but to to drive more complex transformations that are exposed in the transform dialect and are too complex to be expressed as simple declarative rewrites. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D148013
1 parent a3ee34f commit 63c9d2b

File tree

16 files changed

+916
-95
lines changed

16 files changed

+916
-95
lines changed

mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform)
2828
add_mlir_interface(TransformInterfaces)
2929
add_mlir_doc(TransformInterfaces TransformOpInterfaces Dialects/ -gen-op-interface-docs)
3030

31+
add_mlir_interface(MatchInterfaces)
32+
add_dependencies(MLIRMatchInterfacesIncGen MLIRTransformInterfacesIncGen)
33+
add_mlir_doc(TransformInterfaces MatchOpInterfaces Dialects/ -gen-op-interface-docs)
34+
3135
set(LLVM_TARGET_DEFINITIONS TransformInterfaces.td)
3236
mlir_tablegen(TransformTypeInterfaces.h.inc -gen-type-interface-decls)
3337
mlir_tablegen(TransformTypeInterfaces.cpp.inc -gen-type-interface-defs)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//===- MatchInterfaces.h - Transform Dialect Interfaces ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
10+
#define MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13+
#include "mlir/IR/OpDefinition.h"
14+
15+
namespace mlir {
16+
namespace transform {
17+
class MatchOpInterface;
18+
19+
template <typename OpTy>
20+
class SingleOpMatcherOpTrait
21+
: public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
22+
template <typename T>
23+
using has_get_operand_handle =
24+
decltype(std::declval<T &>().getOperandHandle());
25+
template <typename T>
26+
using has_match_operation = decltype(std::declval<T &>().matchOperation(
27+
std::declval<Operation *>(), std::declval<TransformResults &>(),
28+
std::declval<TransformState &>()));
29+
30+
public:
31+
static LogicalResult verifyTrait(Operation *op) {
32+
static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
33+
"SingleOpMatcherOpTrait expects operation type to have the "
34+
"getOperandHandle() method");
35+
static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
36+
"SingleOpMatcherOpTrait expected operation type to have the "
37+
"matchOperation(Operation *, TransformResults &, "
38+
"TransformState &) method");
39+
40+
// This must be a dynamic assert because interface registration is dynamic.
41+
assert(isa<MatchOpInterface>(op) &&
42+
"SingleOpMatchOpTrait is only available on operations with "
43+
"MatchOpInterface");
44+
Value operandHandle = cast<OpTy>(op).getOperandHandle();
45+
if (!operandHandle.getType().isa<TransformHandleTypeInterface>()) {
46+
return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
47+
"to be of TransformHandleTypeInterface";
48+
}
49+
50+
return success();
51+
}
52+
53+
DiagnosedSilenceableFailure apply(TransformResults &results,
54+
TransformState &state) {
55+
Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
56+
ArrayRef<Operation *> payload = state.getPayloadOps(operandHandle);
57+
if (payload.size() != 1) {
58+
return emitDefiniteFailure(this->getOperation()->getLoc())
59+
<< "SingleOpMatchOpTrait requires the operand handle to point to "
60+
"a single payload op";
61+
}
62+
63+
return cast<OpTy>(this->getOperation())
64+
.matchOperation(payload[0], results, state);
65+
}
66+
};
67+
68+
} // namespace transform
69+
} // namespace mlir
70+
71+
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h.inc"
72+
73+
#endif // MLIR_DIALECT_TRANSFORM_IR_MATCHINTERFACES_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- MatchInterfaces.td - Transform dialect interfaces ---*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
include "mlir/IR/OpBase.td"
10+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
11+
12+
def MatchOpInterface
13+
: OpInterface<"MatchOpInterface", [TransformOpInterface]> {
14+
let cppNamespace = "::mlir::transform";
15+
}
16+
17+
def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
18+
let cppNamespace = "::mlir::transform";
19+
20+
string extraDeclaration = [{
21+
::mlir::DiagnosedSilenceableFailure matchOperation(
22+
::mlir::Operation *current,
23+
::mlir::transform::TransformResults &results,
24+
::mlir::transform::TransformState &state);
25+
}];
26+
}

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
#include "mlir/Support/LogicalResult.h"
1818

1919
namespace mlir {
20+
2021
namespace transform {
2122

2223
class TransformOpInterface;
24+
class TransformResults;
2325

2426
/// Options controlling the application of transform operations by the
2527
/// TransformState.
@@ -400,6 +402,11 @@ class TransformState {
400402
return it->second;
401403
}
402404

405+
/// Updates the state to include the associations between op results and the
406+
/// provided result of applying a transform op.
407+
LogicalResult updateStateFromResults(const TransformResults &results,
408+
ResultRange opResults);
409+
403410
/// Sets the payload IR ops associated with the given transform IR value
404411
/// (handle). A payload op may be associated multiple handles as long as
405412
/// at most one of them gets consumed by further transformations.
@@ -690,6 +697,11 @@ LogicalResult verifyTransformOpInterface(Operation *op);
690697
void prepareValueMappings(
691698
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings,
692699
ValueRange values, const transform::TransformState &state);
700+
701+
/// Populates `results` with payload associations that match exactly those of
702+
/// the operands to `block`'s terminator.
703+
void forwardTerminatorOperands(Block *block, transform::TransformState &state,
704+
transform::TransformResults &results);
693705
} // namespace detail
694706

695707
/// This trait is supposed to be attached to Transform dialect operations that

mlir/include/mlir/Dialect/Transform/IR/TransformOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
1111

1212
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
13+
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
1314
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1415
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
1516
#include "mlir/IR/FunctionInterfaces.h"

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1717
include "mlir/IR/FunctionInterfaces.td"
1818
include "mlir/IR/OpAsmInterface.td"
1919
include "mlir/IR/SymbolInterfaces.td"
20+
include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
2021
include "mlir/Dialect/Transform/IR/TransformAttrs.td"
2122
include "mlir/Dialect/Transform/IR/TransformDialect.td"
2223
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
@@ -116,6 +117,69 @@ def CastOp : TransformDialectOp<"cast",
116117
}];
117118
}
118119

120+
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
121+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
122+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
123+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
124+
let summary = "Applies named sequences when a named matcher succeeds";
125+
let description = [{
126+
Given a pair of co-indexed lists of transform dialect symbols (such as
127+
`transform.named_sequence`), walks the payload IR associated with the root
128+
handle and interprets the symbols as matcher/action pairs by applying the
129+
body of the corresponding symbol definition. The symbol from the first list
130+
is the matcher part: if it results in a silenceable error, the error is
131+
silenced and the next matcher is attempted. Definite failures from any
132+
matcher stop the application immediately and are propagated unconditionally.
133+
If none of the matchers succeeds, the next payload operation in walk order
134+
(post-order at the moment of writing, double check `Operation::walk`) is
135+
matched. If a matcher succeeds, the co-indexed action symbol is applied and
136+
the following matchers are not applied to the same payload operation. If the
137+
action succeeds, the next payload operation in walk order is matched. If it
138+
fails, both silenceable and definite errors are propagated as the result of
139+
this op.
140+
141+
The matcher symbol must take one operand of a type that implements the same
142+
transform dialect interface as the `root` operand (a check is performed at
143+
application time to see if the associated payload satisfies the constraints
144+
of the actual type). It must not consume the operand as multiple matchers
145+
may be applied. The matcher may produce any number of results. The action
146+
symbol paired with the matcher must take the same number of arguments as the
147+
matcher has results, and these arguments must implement the same transform
148+
dialect interfaces, but not necessarily have the exact same type (again, a
149+
check is performed at application time to see if the associated payload
150+
satisfies the constraints of actual types on both sides). The action symbol
151+
may not have results. The actions are expected to only modify payload
152+
operations nested in the `root` payload operations associated with the
153+
operand of this transform operation.
154+
155+
This operation consumes the operand and produces a new handle associated
156+
with the same payload. This is necessary to trigger invalidation of handles
157+
to any of the payload operations nested in the payload operations associated
158+
with the operand, as those are likely to be modified by actions. Note that
159+
the root payload operation associated with the operand are not matched.
160+
161+
The operation succeeds if none of the matchers produced a definite failure
162+
during application and if all of the applied actions produced success. Note
163+
that it also succeeds if all the matchers failed on all payload operations,
164+
i.e. failure to apply is not an error. The operation produces a silenceable
165+
failure if any applied action produced a silenceable failure. In this case,
166+
the resulting handle is associated with an empty payload. The operation
167+
produces a definite failure if any of the applied matchers or actions
168+
produced a definite failure.
169+
}];
170+
171+
let arguments = (ins TransformHandleTypeInterface:$root,
172+
SymbolRefArrayAttr:$matchers,
173+
SymbolRefArrayAttr:$actions);
174+
let results = (outs TransformHandleTypeInterface:$updated);
175+
176+
let assemblyFormat =
177+
"`in` $root custom<ForeachMatchSymbols>($matchers, $actions) "
178+
"attr-dict `:` functional-type($root, $updated)";
179+
180+
let hasVerifier = 1;
181+
}
182+
119183
def ForeachOp : TransformDialectOp<"foreach",
120184
[DeclareOpInterfaceMethods<TransformOpInterface>,
121185
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
@@ -270,6 +334,7 @@ def GetResultOp : TransformDialectOp<"get_result",
270334

271335
def IncludeOp : TransformDialectOp<"include",
272336
[CallOpInterface,
337+
MatchOpInterface,
273338
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
274339
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
275340
DeclareOpInterfaceMethods<TransformOpInterface>]> {

mlir/lib/Dialect/Transform/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
add_mlir_dialect_library(MLIRTransformDialect
2+
MatchInterfaces.cpp
23
TransformDialect.cpp
34
TransformInterfaces.cpp
45
TransformOps.cpp
56
TransformTypes.cpp
67

78
DEPENDS
9+
MLIRMatchInterfacesIncGen
810
MLIRTransformDialectIncGen
911
MLIRTransformInterfacesIncGen
1012

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//===- MatchInterfaces.cpp - Transform Dialect Interfaces -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
10+
11+
using namespace mlir;
12+
13+
//===----------------------------------------------------------------------===//
14+
// Generated interface implementation.
15+
//===----------------------------------------------------------------------===//
16+
17+
#include "mlir/Dialect/Transform/IR/MatchInterfaces.cpp.inc"

mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -920,40 +920,44 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
920920
}
921921
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
922922

923-
for (OpResult result : transform->getResults()) {
924-
assert(result.getDefiningOp() == transform.getOperation() &&
925-
"payload IR association for a value other than the result of the "
926-
"current transform op");
923+
if (failed(updateStateFromResults(results, transform->getResults())))
924+
return DiagnosedSilenceableFailure::definiteFailure();
925+
926+
printOnFailureRAII.release();
927+
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
928+
DBGS() << "Top-level payload:\n";
929+
getTopLevel()->print(llvm::dbgs());
930+
});
931+
return result;
932+
}
933+
934+
LogicalResult transform::TransformState::updateStateFromResults(
935+
const TransformResults &results, ResultRange opResults) {
936+
for (OpResult result : opResults) {
927937
if (result.getType().isa<TransformParamTypeInterface>()) {
928938
assert(results.isParam(result.getResultNumber()) &&
929939
"expected parameters for the parameter-typed result");
930940
if (failed(
931941
setParams(result, results.getParams(result.getResultNumber())))) {
932-
return DiagnosedSilenceableFailure::definiteFailure();
942+
return failure();
933943
}
934944
} else if (result.getType().isa<TransformValueHandleTypeInterface>()) {
935945
assert(results.isValue(result.getResultNumber()) &&
936946
"expected values for value-type-result");
937947
if (failed(setPayloadValues(
938948
result, results.getValues(result.getResultNumber())))) {
939-
return DiagnosedSilenceableFailure::definiteFailure();
949+
return failure();
940950
}
941951
} else {
942952
assert(!results.isParam(result.getResultNumber()) &&
943953
"expected payload ops for the non-parameter typed result");
944954
if (failed(
945955
setPayloadOps(result, results.get(result.getResultNumber())))) {
946-
return DiagnosedSilenceableFailure::definiteFailure();
956+
return failure();
947957
}
948958
}
949959
}
950-
951-
printOnFailureRAII.release();
952-
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
953-
DBGS() << "Top-level payload:\n";
954-
getTopLevel()->print(llvm::dbgs());
955-
});
956-
return result;
960+
return success();
957961
}
958962

959963
//===----------------------------------------------------------------------===//
@@ -1193,7 +1197,7 @@ void transform::detail::setApplyToOneResults(
11931197
}
11941198

11951199
//===----------------------------------------------------------------------===//
1196-
// Utilities for PossibleTopLevelTransformOpTrait.
1200+
// Utilities for implementing transform ops with regions.
11971201
//===----------------------------------------------------------------------===//
11981202

11991203
void transform::detail::prepareValueMappings(
@@ -1213,6 +1217,29 @@ void transform::detail::prepareValueMappings(
12131217
}
12141218
}
12151219

1220+
void transform::detail::forwardTerminatorOperands(
1221+
Block *block, transform::TransformState &state,
1222+
transform::TransformResults &results) {
1223+
for (auto &&[terminatorOperand, result] :
1224+
llvm::zip(block->getTerminator()->getOperands(),
1225+
block->getParentOp()->getOpResults())) {
1226+
if (result.getType().isa<transform::TransformHandleTypeInterface>()) {
1227+
results.set(result, state.getPayloadOps(terminatorOperand));
1228+
} else if (result.getType()
1229+
.isa<transform::TransformValueHandleTypeInterface>()) {
1230+
results.setValues(result, state.getPayloadValues(terminatorOperand));
1231+
} else {
1232+
assert(result.getType().isa<transform::TransformParamTypeInterface>() &&
1233+
"unhandled transform type interface");
1234+
results.setParams(result, state.getParams(terminatorOperand));
1235+
}
1236+
}
1237+
}
1238+
1239+
//===----------------------------------------------------------------------===//
1240+
// Utilities for PossibleTopLevelTransformOpTrait.
1241+
//===----------------------------------------------------------------------===//
1242+
12161243
LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
12171244
TransformState &state, Operation *op, Region &region) {
12181245
SmallVector<Operation *> targets;

0 commit comments

Comments
 (0)