Skip to content

Commit 42b1603

Browse files
authored
[mlir][transform] Add an op for replacing values with function calls (#78398)
Adds `transform.func.cast_and_call` that takes a set of inputs and outputs and replaces the uses of those outputs with a call to a function at a specified insertion point. The idea with this operation is to allow users to author independent IR outside of a to-be-compiled module, and then match and replace a slice of the program with a call to the external function. Additionally adds a mechanism for populating a type converter with a set of conversion materialization functions that allow insertion of casts on the inputs/outputs to and from the types of the function signature.
1 parent 0784b1e commit 42b1603

File tree

10 files changed

+538
-5
lines changed

10 files changed

+538
-5
lines changed

mlir/include/mlir/Dialect/Func/TransformOps/FuncTransformOps.td

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
include "mlir/Dialect/Transform/IR/TransformDialect.td"
1313
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
1414
include "mlir/Dialect/Transform/IR/TransformTypes.td"
15+
include "mlir/Interfaces/SideEffectInterfaces.td"
16+
include "mlir/IR/RegionKindInterface.td"
1517
include "mlir/IR/OpBase.td"
1618

1719
def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
@@ -26,4 +28,74 @@ def ApplyFuncToLLVMConversionPatternsOp : Op<Transform_Dialect,
2628
let assemblyFormat = "attr-dict";
2729
}
2830

31+
def CastAndCallOp : Op<Transform_Dialect,
32+
"func.cast_and_call",
33+
[DeclareOpInterfaceMethods<TransformOpInterface>,
34+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
35+
AttrSizedOperandSegments,
36+
ReportTrackingListenerFailuresOpTrait]
37+
# GraphRegionNoTerminator.traits> {
38+
let summary = "Casts values to the signature of a function and replaces them "
39+
"with a call";
40+
let description = [{
41+
This transform takes value handles to a set of `inputs` and `outputs` and
42+
attempts to cast them to the function signature of the attached function
43+
op, then builds a call to the function and replaces the users of the
44+
outputs. It is the responsibility of the user to ensure that the slice of
45+
the program replaced by this operation makes sense, i.e. there is no
46+
verification that the inputs to this operation have any relation to the
47+
outputs outside of basic dominance requirements needed for the call.
48+
49+
The casting materialization functions are specified in the graph region of
50+
this op. They must implement the `TypeConverterBuilderOpInterface`. The
51+
order of ops within the region is irrelevant.
52+
53+
The target function can be specified by a symbol name or by a handle to the
54+
operation.
55+
56+
This transform only reads the operand handles and only replaces the users of
57+
the outputs with the results of the call. No handles are consumed and no
58+
operations are removed. Users are expected to run cleanup separately if
59+
desired.
60+
61+
Warning: The replacement of the uses of the outputs could invalidate certain
62+
restricted value handle types (e.g. `transform.block_arg` if it existed, by
63+
replacing the use with something not coming from a block argument). The
64+
value will still exist in such cases but wouldn't verify against the type.
65+
See the discussion here for more information:
66+
https://github.com/llvm/llvm-project/pull/78398#discussion_r1455070087
67+
68+
This transform will emit a silenceable failure if:
69+
- The set of outputs isn't unique
70+
- The handle for the insertion point does not include exactly one operation
71+
- The insertion point op does not dominate any of the output users
72+
- The insertion point op is not dominated by any of the inputs
73+
- The function signature does not match the number of inputs/outputs
74+
75+
This transform will emit a definite failure if it fails to resolve the
76+
target function, or if it fails to materialize the conversion casts of
77+
either the inputs to the function argument types, or the call results to
78+
the output types.
79+
}];
80+
81+
let arguments = (ins
82+
TransformHandleTypeInterface:$insertion_point,
83+
UnitAttr:$insert_after,
84+
Optional<TransformValueHandleTypeInterface>:$inputs,
85+
Optional<TransformValueHandleTypeInterface>:$outputs,
86+
OptionalAttr<SymbolRefAttr>:$function_name,
87+
Optional<TransformHandleTypeInterface>:$function);
88+
let results = (outs TransformHandleTypeInterface:$result);
89+
let regions = (region MaxSizedRegion<1>:$conversions);
90+
91+
let assemblyFormat = [{
92+
($function_name^)? ($function^)?
93+
( `(` $inputs^ `)` )?
94+
( `->` $outputs^ )?
95+
(`after` $insert_after^):(`before`)? $insertion_point
96+
($conversions^)? attr-dict `:` functional-type(operands, results)
97+
}];
98+
let hasVerifier = 1;
99+
}
100+
29101
#endif // FUNC_TRANSFORM_OPS

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ include "mlir/IR/OpBase.td"
1818
def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
1919
"apply_conversion_patterns.memref.memref_to_llvm_type_converter",
2020
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
21-
["getTypeConverterType"]>]> {
21+
["getTypeConverter",
22+
"getTypeConverterType"]>]> {
2223
let description = [{
2324
This operation provides an "LLVMTypeConverter" that lowers memref types to
2425
LLVM types.

mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,22 @@ def MakeLoopIndependentOp
169169
}];
170170
}
171171

172+
def TypeConversionCastShapeDynamicDimsOp : Op<Transform_Dialect,
173+
"type_conversion.tensor.cast_shape_dynamic_dims",
174+
[DeclareOpInterfaceMethods<TypeConverterBuilderOpInterface,
175+
["populateTypeMaterializations"]>]> {
176+
let description = [{
177+
Populates a type converter with conversion materialization functions that
178+
cast a tensor value between two cast-compatible tensors. See `tensor.cast`
179+
for more information on cast compatibility between tensors.
180+
181+
If `ignore_dynamic_info` is not set, this will set an additional constraint
182+
that source materializations do not cast dynamic dimensions to static ones.
183+
}];
184+
let arguments = (ins UnitAttr:$ignore_dynamic_info);
185+
186+
let assemblyFormat =
187+
"(`ignore_dynamic_info` $ignore_dynamic_info^)? attr-dict";
188+
}
189+
172190
#endif // TENSOR_TRANSFORM_OPS

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,14 @@ def TypeConverterBuilderOpInterface
284284
: OpInterface<"TypeConverterBuilderOpInterface"> {
285285
let description = [{
286286
This interface should be implemented by ops that specify a type converter
287-
for a dialect conversion. Such ops can be used with
288-
"apply_conversion_patterns".
287+
for a dialect conversion, or to populate a type converter with
288+
conversions.
289+
290+
When such ops are intended to be used with "apply_conversion_patterns" or
291+
other operations that expect a type converter, a non-default implementation
292+
of `getTypeConverter` should be implemented. For use with "cast_and_call"
293+
like ops that construct a type converter iteratively, non-default
294+
`populateTypeMaterializations` should be implemented.
289295
}];
290296

291297
let cppNamespace = "::mlir::transform";
@@ -297,7 +303,11 @@ def TypeConverterBuilderOpInterface
297303
}],
298304
/*returnType=*/"std::unique_ptr<::mlir::TypeConverter>",
299305
/*name=*/"getTypeConverter",
300-
/*arguments=*/(ins)
306+
/*arguments=*/(ins),
307+
/*methodBody=*/"",
308+
/*defaultImplementation=*/[{
309+
return std::make_unique<::mlir::TypeConverter>();
310+
}]
301311
>,
302312
StaticInterfaceMethod<
303313
/*desc=*/[{
@@ -310,6 +320,17 @@ def TypeConverterBuilderOpInterface
310320
/*methodBody=*/"",
311321
/*defaultImplementation=*/[{ return "TypeConverter"; }]
312322
>,
323+
InterfaceMethod<
324+
/*desc=*/[{
325+
Populate the given type converter with source/target materialization
326+
functions.
327+
}],
328+
/*returnType=*/"void",
329+
/*name=*/"populateTypeMaterializations",
330+
/*arguments=*/(ins "::mlir::TypeConverter &":$converter),
331+
/*methodBody=*/"",
332+
/*defaultImplementation=*/[{ return; }]
333+
>,
313334
];
314335
}
315336

mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1616
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
1717
#include "mlir/Dialect/Transform/IR/TransformOps.h"
18+
#include "mlir/Transforms/DialectConversion.h"
1819

1920
using namespace mlir;
2021

@@ -36,6 +37,196 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
3637
return success();
3738
}
3839

40+
//===----------------------------------------------------------------------===//
41+
// CastAndCallOp
42+
//===----------------------------------------------------------------------===//
43+
44+
DiagnosedSilenceableFailure
45+
transform::CastAndCallOp::apply(transform::TransformRewriter &rewriter,
46+
transform::TransformResults &results,
47+
transform::TransformState &state) {
48+
SmallVector<Value> inputs;
49+
if (getInputs())
50+
llvm::append_range(inputs, state.getPayloadValues(getInputs()));
51+
52+
SetVector<Value> outputs;
53+
if (getOutputs()) {
54+
for (auto output : state.getPayloadValues(getOutputs()))
55+
outputs.insert(output);
56+
57+
// Verify that the set of output values to be replaced is unique.
58+
if (outputs.size() !=
59+
llvm::range_size(state.getPayloadValues(getOutputs()))) {
60+
return emitSilenceableFailure(getLoc())
61+
<< "cast and call output values must be unique";
62+
}
63+
}
64+
65+
// Get the insertion point for the call.
66+
auto insertionOps = state.getPayloadOps(getInsertionPoint());
67+
if (!llvm::hasSingleElement(insertionOps)) {
68+
return emitSilenceableFailure(getLoc())
69+
<< "Only one op can be specified as an insertion point";
70+
}
71+
bool insertAfter = getInsertAfter();
72+
Operation *insertionPoint = *insertionOps.begin();
73+
74+
// Check that all inputs dominate the insertion point, and the insertion
75+
// point dominates all users of the outputs.
76+
DominanceInfo dom(insertionPoint);
77+
for (Value output : outputs) {
78+
for (Operation *user : output.getUsers()) {
79+
// If we are inserting after the insertion point operation, the
80+
// insertion point operation must properly dominate the user. Otherwise
81+
// basic dominance is enough.
82+
bool doesDominate = insertAfter
83+
? dom.properlyDominates(insertionPoint, user)
84+
: dom.dominates(insertionPoint, user);
85+
if (!doesDominate) {
86+
return emitDefiniteFailure()
87+
<< "User " << user << " is not dominated by insertion point "
88+
<< insertionPoint;
89+
}
90+
}
91+
}
92+
93+
for (Value input : inputs) {
94+
// If we are inserting before the insertion point operation, the
95+
// input must properly dominate the insertion point operation. Otherwise
96+
// basic dominance is enough.
97+
bool doesDominate = insertAfter
98+
? dom.dominates(input, insertionPoint)
99+
: dom.properlyDominates(input, insertionPoint);
100+
if (!doesDominate) {
101+
return emitDefiniteFailure()
102+
<< "input " << input << " does not dominate insertion point "
103+
<< insertionPoint;
104+
}
105+
}
106+
107+
// Get the function to call. This can either be specified by symbol or as a
108+
// transform handle.
109+
func::FuncOp targetFunction = nullptr;
110+
if (getFunctionName()) {
111+
targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112+
insertionPoint, *getFunctionName());
113+
if (!targetFunction) {
114+
return emitDefiniteFailure()
115+
<< "unresolved symbol " << *getFunctionName();
116+
}
117+
} else if (getFunction()) {
118+
auto payloadOps = state.getPayloadOps(getFunction());
119+
if (!llvm::hasSingleElement(payloadOps)) {
120+
return emitDefiniteFailure() << "requires a single function to call";
121+
}
122+
targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin());
123+
if (!targetFunction) {
124+
return emitDefiniteFailure() << "invalid non-function callee";
125+
}
126+
} else {
127+
llvm_unreachable("Invalid CastAndCall op without a function to call");
128+
return emitDefiniteFailure();
129+
}
130+
131+
// Verify that the function argument and result lengths match the inputs and
132+
// outputs given to this op.
133+
if (targetFunction.getNumArguments() != inputs.size()) {
134+
return emitSilenceableFailure(targetFunction.getLoc())
135+
<< "mismatch between number of function arguments "
136+
<< targetFunction.getNumArguments() << " and number of inputs "
137+
<< inputs.size();
138+
}
139+
if (targetFunction.getNumResults() != outputs.size()) {
140+
return emitSilenceableFailure(targetFunction.getLoc())
141+
<< "mismatch between number of function results "
142+
<< targetFunction->getNumResults() << " and number of outputs "
143+
<< outputs.size();
144+
}
145+
146+
// Gather all specified converters.
147+
mlir::TypeConverter converter;
148+
if (!getRegion().empty()) {
149+
for (Operation &op : getRegion().front()) {
150+
cast<transform::TypeConverterBuilderOpInterface>(&op)
151+
.populateTypeMaterializations(converter);
152+
}
153+
}
154+
155+
if (insertAfter)
156+
rewriter.setInsertionPointAfter(insertionPoint);
157+
else
158+
rewriter.setInsertionPoint(insertionPoint);
159+
160+
for (auto [input, type] :
161+
llvm::zip_equal(inputs, targetFunction.getArgumentTypes())) {
162+
if (input.getType() != type) {
163+
Value newInput = converter.materializeSourceConversion(
164+
rewriter, input.getLoc(), type, input);
165+
if (!newInput) {
166+
return emitDefiniteFailure() << "Failed to materialize conversion of "
167+
<< input << " to type " << type;
168+
}
169+
input = newInput;
170+
}
171+
}
172+
173+
auto callOp = rewriter.create<func::CallOp>(insertionPoint->getLoc(),
174+
targetFunction, inputs);
175+
176+
// Cast the call results back to the expected types. If any conversions fail
177+
// this is a definite failure as the call has been constructed at this point.
178+
for (auto [output, newOutput] :
179+
llvm::zip_equal(outputs, callOp.getResults())) {
180+
Value convertedOutput = newOutput;
181+
if (output.getType() != newOutput.getType()) {
182+
convertedOutput = converter.materializeTargetConversion(
183+
rewriter, output.getLoc(), output.getType(), newOutput);
184+
if (!convertedOutput) {
185+
return emitDefiniteFailure()
186+
<< "Failed to materialize conversion of " << newOutput
187+
<< " to type " << output.getType();
188+
}
189+
}
190+
rewriter.replaceAllUsesExcept(output, convertedOutput, callOp);
191+
}
192+
results.set(cast<OpResult>(getResult()), {callOp});
193+
return DiagnosedSilenceableFailure::success();
194+
}
195+
196+
LogicalResult transform::CastAndCallOp::verify() {
197+
if (!getRegion().empty()) {
198+
for (Operation &op : getRegion().front()) {
199+
if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200+
InFlightDiagnostic diag = emitOpError()
201+
<< "expected children ops to implement "
202+
"TypeConverterBuilderOpInterface";
203+
diag.attachNote(op.getLoc()) << "op without interface";
204+
return diag;
205+
}
206+
}
207+
}
208+
if (!getFunction() && !getFunctionName()) {
209+
return emitOpError() << "expected a function handle or name to call";
210+
}
211+
if (getFunction() && getFunctionName()) {
212+
return emitOpError() << "function handle and name are mutually exclusive";
213+
}
214+
return success();
215+
}
216+
217+
void transform::CastAndCallOp::getEffects(
218+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219+
transform::onlyReadsHandle(getInsertionPoint(), effects);
220+
if (getInputs())
221+
transform::onlyReadsHandle(getInputs(), effects);
222+
if (getOutputs())
223+
transform::onlyReadsHandle(getOutputs(), effects);
224+
if (getFunction())
225+
transform::onlyReadsHandle(getFunction(), effects);
226+
transform::producesHandle(getResult(), effects);
227+
transform::modifiesPayload(effects);
228+
}
229+
39230
//===----------------------------------------------------------------------===//
40231
// Transform op registration
41232
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)