15
15
#include " mlir/Dialect/Transform/IR/TransformDialect.h"
16
16
#include " mlir/Dialect/Transform/IR/TransformInterfaces.h"
17
17
#include " mlir/Dialect/Transform/IR/TransformOps.h"
18
+ #include " mlir/Transforms/DialectConversion.h"
18
19
19
20
using namespace mlir ;
20
21
@@ -36,6 +37,196 @@ transform::ApplyFuncToLLVMConversionPatternsOp::verifyTypeConverter(
36
37
return success ();
37
38
}
38
39
40
+ // ===----------------------------------------------------------------------===//
41
+ // CastAndCallOp
42
+ // ===----------------------------------------------------------------------===//
43
+
44
+ DiagnosedSilenceableFailure
45
+ transform::CastAndCallOp::apply (transform::TransformRewriter &rewriter,
46
+ transform::TransformResults &results,
47
+ transform::TransformState &state) {
48
+ SmallVector<Value> inputs;
49
+ if (getInputs ())
50
+ llvm::append_range (inputs, state.getPayloadValues (getInputs ()));
51
+
52
+ SetVector<Value> outputs;
53
+ if (getOutputs ()) {
54
+ for (auto output : state.getPayloadValues (getOutputs ()))
55
+ outputs.insert (output);
56
+
57
+ // Verify that the set of output values to be replaced is unique.
58
+ if (outputs.size () !=
59
+ llvm::range_size (state.getPayloadValues (getOutputs ()))) {
60
+ return emitSilenceableFailure (getLoc ())
61
+ << " cast and call output values must be unique" ;
62
+ }
63
+ }
64
+
65
+ // Get the insertion point for the call.
66
+ auto insertionOps = state.getPayloadOps (getInsertionPoint ());
67
+ if (!llvm::hasSingleElement (insertionOps)) {
68
+ return emitSilenceableFailure (getLoc ())
69
+ << " Only one op can be specified as an insertion point" ;
70
+ }
71
+ bool insertAfter = getInsertAfter ();
72
+ Operation *insertionPoint = *insertionOps.begin ();
73
+
74
+ // Check that all inputs dominate the insertion point, and the insertion
75
+ // point dominates all users of the outputs.
76
+ DominanceInfo dom (insertionPoint);
77
+ for (Value output : outputs) {
78
+ for (Operation *user : output.getUsers ()) {
79
+ // If we are inserting after the insertion point operation, the
80
+ // insertion point operation must properly dominate the user. Otherwise
81
+ // basic dominance is enough.
82
+ bool doesDominate = insertAfter
83
+ ? dom.properlyDominates (insertionPoint, user)
84
+ : dom.dominates (insertionPoint, user);
85
+ if (!doesDominate) {
86
+ return emitDefiniteFailure ()
87
+ << " User " << user << " is not dominated by insertion point "
88
+ << insertionPoint;
89
+ }
90
+ }
91
+ }
92
+
93
+ for (Value input : inputs) {
94
+ // If we are inserting before the insertion point operation, the
95
+ // input must properly dominate the insertion point operation. Otherwise
96
+ // basic dominance is enough.
97
+ bool doesDominate = insertAfter
98
+ ? dom.dominates (input, insertionPoint)
99
+ : dom.properlyDominates (input, insertionPoint);
100
+ if (!doesDominate) {
101
+ return emitDefiniteFailure ()
102
+ << " input " << input << " does not dominate insertion point "
103
+ << insertionPoint;
104
+ }
105
+ }
106
+
107
+ // Get the function to call. This can either be specified by symbol or as a
108
+ // transform handle.
109
+ func::FuncOp targetFunction = nullptr ;
110
+ if (getFunctionName ()) {
111
+ targetFunction = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
112
+ insertionPoint, *getFunctionName ());
113
+ if (!targetFunction) {
114
+ return emitDefiniteFailure ()
115
+ << " unresolved symbol " << *getFunctionName ();
116
+ }
117
+ } else if (getFunction ()) {
118
+ auto payloadOps = state.getPayloadOps (getFunction ());
119
+ if (!llvm::hasSingleElement (payloadOps)) {
120
+ return emitDefiniteFailure () << " requires a single function to call" ;
121
+ }
122
+ targetFunction = dyn_cast<func::FuncOp>(*payloadOps.begin ());
123
+ if (!targetFunction) {
124
+ return emitDefiniteFailure () << " invalid non-function callee" ;
125
+ }
126
+ } else {
127
+ llvm_unreachable (" Invalid CastAndCall op without a function to call" );
128
+ return emitDefiniteFailure ();
129
+ }
130
+
131
+ // Verify that the function argument and result lengths match the inputs and
132
+ // outputs given to this op.
133
+ if (targetFunction.getNumArguments () != inputs.size ()) {
134
+ return emitSilenceableFailure (targetFunction.getLoc ())
135
+ << " mismatch between number of function arguments "
136
+ << targetFunction.getNumArguments () << " and number of inputs "
137
+ << inputs.size ();
138
+ }
139
+ if (targetFunction.getNumResults () != outputs.size ()) {
140
+ return emitSilenceableFailure (targetFunction.getLoc ())
141
+ << " mismatch between number of function results "
142
+ << targetFunction->getNumResults () << " and number of outputs "
143
+ << outputs.size ();
144
+ }
145
+
146
+ // Gather all specified converters.
147
+ mlir::TypeConverter converter;
148
+ if (!getRegion ().empty ()) {
149
+ for (Operation &op : getRegion ().front ()) {
150
+ cast<transform::TypeConverterBuilderOpInterface>(&op)
151
+ .populateTypeMaterializations (converter);
152
+ }
153
+ }
154
+
155
+ if (insertAfter)
156
+ rewriter.setInsertionPointAfter (insertionPoint);
157
+ else
158
+ rewriter.setInsertionPoint (insertionPoint);
159
+
160
+ for (auto [input, type] :
161
+ llvm::zip_equal (inputs, targetFunction.getArgumentTypes ())) {
162
+ if (input.getType () != type) {
163
+ Value newInput = converter.materializeSourceConversion (
164
+ rewriter, input.getLoc (), type, input);
165
+ if (!newInput) {
166
+ return emitDefiniteFailure () << " Failed to materialize conversion of "
167
+ << input << " to type " << type;
168
+ }
169
+ input = newInput;
170
+ }
171
+ }
172
+
173
+ auto callOp = rewriter.create <func::CallOp>(insertionPoint->getLoc (),
174
+ targetFunction, inputs);
175
+
176
+ // Cast the call results back to the expected types. If any conversions fail
177
+ // this is a definite failure as the call has been constructed at this point.
178
+ for (auto [output, newOutput] :
179
+ llvm::zip_equal (outputs, callOp.getResults ())) {
180
+ Value convertedOutput = newOutput;
181
+ if (output.getType () != newOutput.getType ()) {
182
+ convertedOutput = converter.materializeTargetConversion (
183
+ rewriter, output.getLoc (), output.getType (), newOutput);
184
+ if (!convertedOutput) {
185
+ return emitDefiniteFailure ()
186
+ << " Failed to materialize conversion of " << newOutput
187
+ << " to type " << output.getType ();
188
+ }
189
+ }
190
+ rewriter.replaceAllUsesExcept (output, convertedOutput, callOp);
191
+ }
192
+ results.set (cast<OpResult>(getResult ()), {callOp});
193
+ return DiagnosedSilenceableFailure::success ();
194
+ }
195
+
196
+ LogicalResult transform::CastAndCallOp::verify () {
197
+ if (!getRegion ().empty ()) {
198
+ for (Operation &op : getRegion ().front ()) {
199
+ if (!isa<transform::TypeConverterBuilderOpInterface>(&op)) {
200
+ InFlightDiagnostic diag = emitOpError ()
201
+ << " expected children ops to implement "
202
+ " TypeConverterBuilderOpInterface" ;
203
+ diag.attachNote (op.getLoc ()) << " op without interface" ;
204
+ return diag;
205
+ }
206
+ }
207
+ }
208
+ if (!getFunction () && !getFunctionName ()) {
209
+ return emitOpError () << " expected a function handle or name to call" ;
210
+ }
211
+ if (getFunction () && getFunctionName ()) {
212
+ return emitOpError () << " function handle and name are mutually exclusive" ;
213
+ }
214
+ return success ();
215
+ }
216
+
217
+ void transform::CastAndCallOp::getEffects (
218
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219
+ transform::onlyReadsHandle (getInsertionPoint (), effects);
220
+ if (getInputs ())
221
+ transform::onlyReadsHandle (getInputs (), effects);
222
+ if (getOutputs ())
223
+ transform::onlyReadsHandle (getOutputs (), effects);
224
+ if (getFunction ())
225
+ transform::onlyReadsHandle (getFunction (), effects);
226
+ transform::producesHandle (getResult (), effects);
227
+ transform::modifiesPayload (effects);
228
+ }
229
+
39
230
// ===----------------------------------------------------------------------===//
40
231
// Transform op registration
41
232
// ===----------------------------------------------------------------------===//
0 commit comments