Skip to content

Commit 70caf27

Browse files
committed
[mlir][emitc] Add a structured for operation
Add an emitc.for op to the EmitC dialect as a lowering target for scf.for, replacing its current direct translation to C; The translator now handles emitc.for instead.
1 parent f2c09e5 commit 70caf27

File tree

9 files changed

+426
-122
lines changed

9 files changed

+426
-122
lines changed

mlir/docs/Dialects/emitc.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,5 @@ translating the following operations:
3131
* `func.constant`
3232
* `func.func`
3333
* `func.return`
34-
* 'scf' Dialect
35-
* `scf.for`
36-
* `scf.yield`
3734
* 'arith' Dialect
3835
* `arith.constant`

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,67 @@ def EmitC_DivOp : EmitC_BinaryOp<"div", []> {
246246
let results = (outs FloatIntegerIndexOrOpaqueType);
247247
}
248248

249+
def EmitC_ForOp : EmitC_Op<"for",
250+
[AllTypesMatch<["lowerBound", "upperBound", "step"]>,
251+
SingleBlockImplicitTerminator<"emitc::YieldOp">,
252+
RecursiveMemoryEffects]> {
253+
let summary = "for operation";
254+
let description = [{
255+
The `emitc.for` operation represents a C loop of the following form:
256+
257+
```c++
258+
for (T i = lb; i < ub; i += step) { /* ... */ } // where T is typeof(lb)
259+
```
260+
261+
The operation takes 3 SSA values as operands that represent the lower bound,
262+
upper bound and step respectively, and defines an SSA value for its
263+
induction variable. It has one region capturing the loop body. The induction
264+
variable is represented as an argument of this region. This SSA value is a
265+
signless integer or index. The step is a value of same type.
266+
267+
This operation has no result. The body region must contain exactly one block
268+
that terminates with `emitc.yield`. Calling ForOp::build will create such a
269+
region and insert the terminator implicitly if none is defined, so will the
270+
parsing even in cases when it is absent from the custom format. For example:
271+
272+
```mlir
273+
// Index case.
274+
emitc.for %iv = %lb to %ub step %step {
275+
... // body
276+
}
277+
...
278+
// Integer case.
279+
emitc.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
280+
... // body
281+
}
282+
```
283+
}];
284+
let arguments = (ins IntegerIndexOrOpaqueType:$lowerBound,
285+
IntegerIndexOrOpaqueType:$upperBound,
286+
IntegerIndexOrOpaqueType:$step);
287+
let results = (outs);
288+
let regions = (region SizedRegion<1>:$region);
289+
290+
let skipDefaultBuilders = 1;
291+
let builders = [
292+
OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound, "Value":$step,
293+
CArg<"function_ref<void(OpBuilder &, Location, Value)>", "nullptr">)>
294+
];
295+
296+
let extraClassDeclaration = [{
297+
using BodyBuilderFn =
298+
function_ref<void(OpBuilder &, Location, Value)>;
299+
Value getInductionVar() { return getBody()->getArgument(0); }
300+
void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
301+
void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
302+
void setStep(Value step) { getOperation()->setOperand(2, step); }
303+
}];
304+
305+
let hasCanonicalizer = 1;
306+
let hasCustomAssemblyFormat = 1;
307+
let hasRegionVerifier = 1;
308+
}
309+
249310
def EmitC_IncludeOp
250311
: EmitC_Op<"include", [HasParent<"ModuleOp">]> {
251312
let summary = "Include operation";
@@ -430,7 +491,8 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
430491
let assemblyFormat = "$value `:` type($value) `to` $var `:` type($var) attr-dict";
431492
}
432493

433-
def EmitC_YieldOp : EmitC_Op<"yield", [Pure, Terminator, ParentOneOf<["IfOp"]>]> {
494+
def EmitC_YieldOp : EmitC_Op<"yield",
495+
[Pure, Terminator, ParentOneOf<["IfOp", "ForOp"]>]> {
434496
let summary = "block termination operation";
435497
let description = [{
436498
"yield" terminates blocks within EmitC control-flow operations. Since

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 99 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,100 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
3737
void runOnOperation() override;
3838
};
3939

40-
// Lower scf::if to emitc::if, implementing return values as emitc::variable's
40+
// Lower scf::for to emitc::for, implementing result values using
41+
// emitc::variable's updated within the loop body.
42+
struct ForLowering : public OpRewritePattern<ForOp> {
43+
using OpRewritePattern<ForOp>::OpRewritePattern;
44+
45+
LogicalResult matchAndRewrite(ForOp forOp,
46+
PatternRewriter &rewriter) const override;
47+
};
48+
49+
// Create an uninitialized emitc::variable op for each result of the given op.
50+
template <typename T>
51+
static SmallVector<Value> createVariablesForResults(T op,
52+
PatternRewriter &rewriter) {
53+
SmallVector<Value> resultVariables;
54+
55+
if (!op.getNumResults())
56+
return resultVariables;
57+
58+
Location loc = op->getLoc();
59+
MLIRContext *context = op.getContext();
60+
61+
OpBuilder::InsertionGuard guard(rewriter);
62+
rewriter.setInsertionPoint(op);
63+
64+
for (OpResult result : op.getResults()) {
65+
Type resultType = result.getType();
66+
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
67+
emitc::VariableOp var =
68+
rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
69+
resultVariables.push_back(var);
70+
}
71+
72+
return resultVariables;
73+
}
74+
75+
// Create a series of assign ops assigning given values to given variables at
76+
// the current insertion point of given rewriter.
77+
static void assignValues(ValueRange values, SmallVector<Value> &variables,
78+
PatternRewriter &rewriter, Location loc) {
79+
for (auto [value, var] : llvm::zip(values, variables))
80+
rewriter.create<emitc::AssignOp>(loc, var, value);
81+
}
82+
83+
static void lowerYield(SmallVector<Value> &resultVariables,
84+
PatternRewriter &rewriter, scf::YieldOp yield) {
85+
Location loc = yield.getLoc();
86+
ValueRange operands = yield.getOperands();
87+
88+
OpBuilder::InsertionGuard guard(rewriter);
89+
rewriter.setInsertionPoint(yield);
90+
91+
assignValues(operands, resultVariables, rewriter, loc);
92+
93+
rewriter.create<emitc::YieldOp>(loc);
94+
rewriter.eraseOp(yield);
95+
}
96+
97+
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
98+
PatternRewriter &rewriter) const {
99+
Location loc = forOp.getLoc();
100+
101+
// Create an emitc::variable op for each result. These variables will be
102+
// assigned to by emitc::assign ops within the loop body.
103+
SmallVector<Value> resultVariables =
104+
createVariablesForResults(forOp, rewriter);
105+
SmallVector<Value> iterArgsVariables =
106+
createVariablesForResults(forOp, rewriter);
107+
108+
assignValues(forOp.getInits(), iterArgsVariables, rewriter, loc);
109+
110+
emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
111+
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
112+
113+
Block *loweredBody = loweredFor.getBody();
114+
115+
// Erase the auto-generated terminator for the lowered for op.
116+
rewriter.eraseOp(loweredBody->getTerminator());
117+
118+
SmallVector<Value> replacingValues;
119+
replacingValues.push_back(loweredFor.getInductionVar());
120+
replacingValues.append(iterArgsVariables.begin(), iterArgsVariables.end());
121+
122+
rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
123+
lowerYield(iterArgsVariables, rewriter,
124+
cast<scf::YieldOp>(loweredBody->getTerminator()));
125+
126+
// Copy iterArgs into results after the for loop.
127+
assignValues(iterArgsVariables, resultVariables, rewriter, loc);
128+
129+
rewriter.replaceOp(forOp, resultVariables);
130+
return success();
131+
}
132+
133+
// Lower scf::if to emitc::if, implementing result values as emitc::variable's
41134
// updated within the then and else regions.
42135
struct IfLowering : public OpRewritePattern<IfOp> {
43136
using OpRewritePattern<IfOp>::OpRewritePattern;
@@ -52,20 +145,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
52145
PatternRewriter &rewriter) const {
53146
Location loc = ifOp.getLoc();
54147

55-
SmallVector<Value> resultVariables;
56-
57148
// Create an emitc::variable op for each result. These variables will be
58149
// assigned to by emitc::assign ops within the then & else regions.
59-
if (ifOp.getNumResults()) {
60-
MLIRContext *context = ifOp.getContext();
61-
rewriter.setInsertionPoint(ifOp);
62-
for (OpResult result : ifOp.getResults()) {
63-
Type resultType = result.getType();
64-
auto noInit = emitc::OpaqueAttr::get(context, "");
65-
auto var = rewriter.create<emitc::VariableOp>(loc, resultType, noInit);
66-
resultVariables.push_back(var);
67-
}
68-
}
150+
SmallVector<Value> resultVariables =
151+
createVariablesForResults(ifOp, rewriter);
69152

70153
// Utility function to lower the contents of an scf::if region to an emitc::if
71154
// region. The contents of the scf::if regions is moved into the respective
@@ -76,16 +159,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
76159
Region &loweredRegion) {
77160
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
78161
Operation *terminator = loweredRegion.back().getTerminator();
79-
Location terminatorLoc = terminator->getLoc();
80-
ValueRange terminatorOperands = terminator->getOperands();
81-
rewriter.setInsertionPointToEnd(&loweredRegion.back());
82-
for (auto value2Var : llvm::zip(terminatorOperands, resultVariables)) {
83-
Value resultValue = std::get<0>(value2Var);
84-
Value resultVar = std::get<1>(value2Var);
85-
rewriter.create<emitc::AssignOp>(terminatorLoc, resultVar, resultValue);
86-
}
87-
rewriter.create<emitc::YieldOp>(terminatorLoc);
88-
rewriter.eraseOp(terminator);
162+
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
89163
};
90164

91165
Region &thenRegion = ifOp.getThenRegion();
@@ -109,6 +183,7 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
109183
}
110184

111185
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
186+
patterns.add<ForLowering>(patterns.getContext());
112187
patterns.add<IfLowering>(patterns.getContext());
113188
}
114189

@@ -118,7 +193,7 @@ void SCFToEmitCPass::runOnOperation() {
118193

119194
// Configure conversion to lower out SCF operations.
120195
ConversionTarget target(getContext());
121-
target.addIllegalOp<scf::IfOp>();
196+
target.addIllegalOp<scf::ForOp, scf::IfOp>();
122197
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
123198
if (failed(
124199
applyPartialConversion(getOperation(), target, std::move(patterns))))

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,101 @@ LogicalResult emitc::ConstantOp::verify() {
189189

190190
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
191191

192+
//===----------------------------------------------------------------------===//
193+
// ForOp
194+
//===----------------------------------------------------------------------===//
195+
196+
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
197+
Value ub, Value step, BodyBuilderFn bodyBuilder) {
198+
result.addOperands({lb, ub, step});
199+
Type t = lb.getType();
200+
Region *bodyRegion = result.addRegion();
201+
bodyRegion->push_back(new Block);
202+
Block &bodyBlock = bodyRegion->front();
203+
bodyBlock.addArgument(t, result.location);
204+
205+
// Create the default terminator if the builder is not provided.
206+
if (!bodyBuilder) {
207+
ForOp::ensureTerminator(*bodyRegion, builder, result.location);
208+
} else {
209+
OpBuilder::InsertionGuard guard(builder);
210+
builder.setInsertionPointToStart(&bodyBlock);
211+
bodyBuilder(builder, result.location, bodyBlock.getArgument(0));
212+
}
213+
}
214+
215+
void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
216+
217+
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
218+
Builder &builder = parser.getBuilder();
219+
Type type;
220+
221+
OpAsmParser::Argument inductionVariable;
222+
OpAsmParser::UnresolvedOperand lb, ub, step;
223+
224+
// Parse the induction variable followed by '='.
225+
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
226+
// Parse loop bounds.
227+
parser.parseOperand(lb) || parser.parseKeyword("to") ||
228+
parser.parseOperand(ub) || parser.parseKeyword("step") ||
229+
parser.parseOperand(step))
230+
return failure();
231+
232+
// Parse the optional initial iteration arguments.
233+
SmallVector<OpAsmParser::Argument, 4> regionArgs;
234+
SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
235+
regionArgs.push_back(inductionVariable);
236+
237+
// Parse optional type, else assume Index.
238+
if (parser.parseOptionalColon())
239+
type = builder.getIndexType();
240+
else if (parser.parseType(type))
241+
return failure();
242+
243+
// Resolve input operands.
244+
regionArgs.front().type = type;
245+
if (parser.resolveOperand(lb, type, result.operands) ||
246+
parser.resolveOperand(ub, type, result.operands) ||
247+
parser.resolveOperand(step, type, result.operands))
248+
return failure();
249+
250+
// Parse the body region.
251+
Region *body = result.addRegion();
252+
if (parser.parseRegion(*body, regionArgs))
253+
return failure();
254+
255+
ForOp::ensureTerminator(*body, builder, result.location);
256+
257+
// Parse the optional attribute list.
258+
if (parser.parseOptionalAttrDict(result.attributes))
259+
return failure();
260+
261+
return success();
262+
}
263+
264+
void ForOp::print(OpAsmPrinter &p) {
265+
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
266+
<< getUpperBound() << " step " << getStep();
267+
268+
p << ' ';
269+
if (Type t = getInductionVar().getType(); !t.isIndex())
270+
p << " : " << t << ' ';
271+
p.printRegion(getRegion(),
272+
/*printEntryBlockArgs=*/false,
273+
/*printBlockTerminators=*/false);
274+
p.printOptionalAttrDict((*this)->getAttrs());
275+
}
276+
277+
LogicalResult ForOp::verifyRegions() {
278+
// Check that the body defines as single block argument for the induction
279+
// variable.
280+
if (getInductionVar().getType() != getLowerBound().getType())
281+
return emitOpError(
282+
"expected induction variable to be same type as bounds and step");
283+
284+
return success();
285+
}
286+
192287
//===----------------------------------------------------------------------===//
193288
// IfOp
194289
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)