Skip to content

Commit 97f0ab7

Browse files
authored
[mlir][emitc] Add 'emitc.switch' op to the dialect (#102331)
This PR is continuation of the [previous one](#101478). As a result, the `emitc::SwitchOp` op was developed inspired by `scf::IndexSwitchOp`. Main points of PR: - Added the `emitc::SwitchOp` op to the EmitC dialect + CppEmitter - Corresponding tests were added - Conversion from the SCF dialect to the EmitC dialect for the op
1 parent 7afb51e commit 97f0ab7

File tree

9 files changed

+1391
-20
lines changed

9 files changed

+1391
-20
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1188,7 +1188,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
11881188
}
11891189

11901190
def EmitC_YieldOp : EmitC_Op<"yield",
1191-
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
1191+
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
11921192
let summary = "Block termination operation";
11931193
let description = [{
11941194
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
@@ -1302,5 +1302,87 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
13021302
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
13031303
}
13041304

1305+
def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
1306+
SingleBlockImplicitTerminator<"emitc::YieldOp">,
1307+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1308+
["getRegionInvocationBounds",
1309+
"getEntrySuccessorRegions"]>]> {
1310+
let summary = "Switch operation";
1311+
let description = [{
1312+
The `emitc.switch` is a control-flow operation that branches to one of
1313+
the given regions based on the values of the argument and the cases.
1314+
The operand to a switch operation is a opaque, integral or pointer
1315+
wide types.
1316+
1317+
The operation always has a "default" region and any number of case regions
1318+
denoted by integer constants. Control-flow transfers to the case region
1319+
whose constant value equals the value of the argument. If the argument does
1320+
not equal any of the case values, control-flow transfer to the "default"
1321+
region.
1322+
1323+
The operation does not return any value. Moreover, case regions must be
1324+
explicitly terminated using the `emitc.yield` operation. Default region is
1325+
yielded implicitly.
1326+
1327+
Example:
1328+
1329+
```mlir
1330+
// Example:
1331+
emitc.switch %0 : i32
1332+
case 2 {
1333+
%1 = emitc.call_opaque "func_b" () : () -> i32
1334+
emitc.yield
1335+
}
1336+
case 5 {
1337+
%2 = emitc.call_opaque "func_a" () : () -> i32
1338+
emitc.yield
1339+
}
1340+
default {
1341+
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1342+
emitc.call_opaque "func2" (%3) : (f32) -> ()
1343+
}
1344+
```
1345+
```c++
1346+
// Code emitted for the operations above.
1347+
switch (v1) {
1348+
case 2: {
1349+
int32_t v2 = func_b();
1350+
break;
1351+
}
1352+
case 5: {
1353+
int32_t v3 = func_a();
1354+
break;
1355+
}
1356+
default: {
1357+
float v4 = 4.200000000e+01f;
1358+
func2(v4);
1359+
break;
1360+
}
1361+
```
1362+
}];
1363+
1364+
let arguments = (ins IntegerIndexOrOpaqueType:$arg, DenseI64ArrayAttr:$cases);
1365+
let results = (outs);
1366+
let regions = (region SizedRegion<1>:$defaultRegion,
1367+
VariadicRegion<SizedRegion<1>>:$caseRegions);
1368+
1369+
let assemblyFormat = [{
1370+
$arg `:` type($arg) attr-dict custom<SwitchCases>($cases, $caseRegions) `\n`
1371+
`` `default` $defaultRegion
1372+
}];
1373+
1374+
let extraClassDeclaration = [{
1375+
/// Get the number of cases.
1376+
unsigned getNumCases();
1377+
1378+
/// Get the default region body.
1379+
Block &getDefaultBlock();
1380+
1381+
/// Get the body of a case region.
1382+
Block &getCaseBlock(unsigned idx);
1383+
}];
1384+
1385+
let hasVerifier = 1;
1386+
}
13051387

13061388
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ static void lowerYield(SmallVector<Value> &resultVariables,
9494
rewriter.eraseOp(yield);
9595
}
9696

97+
// Lower the contents of an scf::if/scf::index_switch regions to an
98+
// emitc::if/emitc::switch region. The contents of the lowering region is
99+
// moved into the respective lowered region, but the scf::yield is replaced not
100+
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
101+
// set the yielded values into the result variables.
102+
static void lowerRegion(SmallVector<Value> &resultVariables,
103+
PatternRewriter &rewriter, Region &region,
104+
Region &loweredRegion) {
105+
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
106+
Operation *terminator = loweredRegion.back().getTerminator();
107+
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
108+
}
109+
97110
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
98111
PatternRewriter &rewriter) const {
99112
Location loc = forOp.getLoc();
@@ -145,18 +158,6 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
145158
SmallVector<Value> resultVariables =
146159
createVariablesForResults(ifOp, rewriter);
147160

148-
// Utility function to lower the contents of an scf::if region to an emitc::if
149-
// region. The contents of the scf::if regions is moved into the respective
150-
// emitc::if regions, but the scf::yield is replaced not only with an
151-
// emitc::yield, but also with a sequence of emitc::assign ops that set the
152-
// yielded values into the result variables.
153-
auto lowerRegion = [&resultVariables, &rewriter](Region &region,
154-
Region &loweredRegion) {
155-
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
156-
Operation *terminator = loweredRegion.back().getTerminator();
157-
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
158-
};
159-
160161
Region &thenRegion = ifOp.getThenRegion();
161162
Region &elseRegion = ifOp.getElseRegion();
162163

@@ -166,20 +167,59 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
166167
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
167168

168169
Region &loweredThenRegion = loweredIf.getThenRegion();
169-
lowerRegion(thenRegion, loweredThenRegion);
170+
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
170171

171172
if (hasElseBlock) {
172173
Region &loweredElseRegion = loweredIf.getElseRegion();
173-
lowerRegion(elseRegion, loweredElseRegion);
174+
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
174175
}
175176

176177
rewriter.replaceOp(ifOp, resultVariables);
177178
return success();
178179
}
179180

181+
// Lower scf::index_switch to emitc::switch, implementing result values as
182+
// emitc::variable's updated within the case and default regions.
183+
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
184+
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
185+
186+
LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
187+
PatternRewriter &rewriter) const override;
188+
};
189+
190+
LogicalResult
191+
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
192+
PatternRewriter &rewriter) const {
193+
Location loc = indexSwitchOp.getLoc();
194+
195+
// Create an emitc::variable op for each result. These variables will be
196+
// assigned to by emitc::assign ops within the case and default regions.
197+
SmallVector<Value> resultVariables =
198+
createVariablesForResults(indexSwitchOp, rewriter);
199+
200+
auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
201+
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
202+
indexSwitchOp.getNumCases());
203+
204+
// Lowering all case regions.
205+
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
206+
loweredSwitch.getCaseRegions())) {
207+
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
208+
std::get<1>(pair));
209+
}
210+
211+
// Lowering default region.
212+
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
213+
loweredSwitch.getDefaultRegion());
214+
215+
rewriter.replaceOp(indexSwitchOp, resultVariables);
216+
return success();
217+
}
218+
180219
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
181220
patterns.add<ForLowering>(patterns.getContext());
182221
patterns.add<IfLowering>(patterns.getContext());
222+
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
183223
}
184224

185225
void SCFToEmitCPass::runOnOperation() {
@@ -188,7 +228,7 @@ void SCFToEmitCPass::runOnOperation() {
188228

189229
// Configure conversion to lower out SCF operations.
190230
ConversionTarget target(getContext());
191-
target.addIllegalOp<scf::ForOp, scf::IfOp>();
231+
target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
192232
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
193233
if (failed(
194234
applyPartialConversion(getOperation(), target, std::move(patterns))))

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,138 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10961096
return success();
10971097
}
10981098

1099+
//===----------------------------------------------------------------------===//
1100+
// SwitchOp
1101+
//===----------------------------------------------------------------------===//
1102+
1103+
/// Parse the case regions and values.
1104+
static ParseResult
1105+
parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
1106+
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1107+
SmallVector<int64_t> caseValues;
1108+
while (succeeded(parser.parseOptionalKeyword("case"))) {
1109+
int64_t value;
1110+
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1111+
if (parser.parseInteger(value) ||
1112+
parser.parseRegion(region, /*arguments=*/{}))
1113+
return failure();
1114+
caseValues.push_back(value);
1115+
}
1116+
cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1117+
return success();
1118+
}
1119+
1120+
/// Print the case regions and values.
1121+
static void printSwitchCases(OpAsmPrinter &p, Operation *op,
1122+
DenseI64ArrayAttr cases, RegionRange caseRegions) {
1123+
for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1124+
p.printNewline();
1125+
p << "case " << value << ' ';
1126+
p.printRegion(*region, /*printEntryBlockArgs=*/false);
1127+
}
1128+
}
1129+
1130+
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1131+
const Twine &name) {
1132+
auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1133+
if (!yield)
1134+
return op.emitOpError("expected region to end with emitc.yield, but got ")
1135+
<< region.front().back().getName();
1136+
1137+
if (yield.getNumOperands() != 0) {
1138+
return (op.emitOpError("expected each region to return ")
1139+
<< "0 values, but " << name << " returns "
1140+
<< yield.getNumOperands())
1141+
.attachNote(yield.getLoc())
1142+
<< "see yield operation here";
1143+
}
1144+
1145+
return success();
1146+
}
1147+
1148+
LogicalResult emitc::SwitchOp::verify() {
1149+
if (!isIntegerIndexOrOpaqueType(getArg().getType()))
1150+
return emitOpError("unsupported type ") << getArg().getType();
1151+
1152+
if (getCases().size() != getCaseRegions().size()) {
1153+
return emitOpError("has ")
1154+
<< getCaseRegions().size() << " case regions but "
1155+
<< getCases().size() << " case values";
1156+
}
1157+
1158+
DenseSet<int64_t> valueSet;
1159+
for (int64_t value : getCases())
1160+
if (!valueSet.insert(value).second)
1161+
return emitOpError("has duplicate case value: ") << value;
1162+
1163+
if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1164+
return failure();
1165+
1166+
for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1167+
if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1168+
return failure();
1169+
1170+
return success();
1171+
}
1172+
1173+
unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1174+
1175+
Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1176+
1177+
Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1178+
assert(idx < getNumCases() && "case index out-of-bounds");
1179+
return getCaseRegions()[idx].front();
1180+
}
1181+
1182+
void SwitchOp::getSuccessorRegions(
1183+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1184+
llvm::copy(getRegions(), std::back_inserter(successors));
1185+
}
1186+
1187+
void SwitchOp::getEntrySuccessorRegions(
1188+
ArrayRef<Attribute> operands,
1189+
SmallVectorImpl<RegionSuccessor> &successors) {
1190+
FoldAdaptor adaptor(operands, *this);
1191+
1192+
// If a constant was not provided, all regions are possible successors.
1193+
auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1194+
if (!arg) {
1195+
llvm::copy(getRegions(), std::back_inserter(successors));
1196+
return;
1197+
}
1198+
1199+
// Otherwise, try to find a case with a matching value. If not, the
1200+
// default region is the only successor.
1201+
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1202+
if (caseValue == arg.getInt()) {
1203+
successors.emplace_back(&caseRegion);
1204+
return;
1205+
}
1206+
}
1207+
successors.emplace_back(&getDefaultRegion());
1208+
}
1209+
1210+
void SwitchOp::getRegionInvocationBounds(
1211+
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1212+
auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1213+
if (!operandValue) {
1214+
// All regions are invoked at most once.
1215+
bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1216+
return;
1217+
}
1218+
1219+
unsigned liveIndex = getNumRegions() - 1;
1220+
const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1221+
1222+
liveIndex = iteratorToInt != getCases().end()
1223+
? std::distance(getCases().begin(), iteratorToInt)
1224+
: liveIndex;
1225+
1226+
for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1227+
++regIndex)
1228+
bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1229+
}
1230+
10991231
//===----------------------------------------------------------------------===//
11001232
// TableGen'd op method definitions
11011233
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)