Skip to content

Commit 98b086a

Browse files
committed
[mlir][emitc] Add 'emitc.switch' op to the dialect
1 parent ad8a2e4 commit 98b086a

File tree

6 files changed

+600
-4
lines changed

6 files changed

+600
-4
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type);
4747
/// Determines whether \p type is a emitc.size_t/ssize_t type.
4848
bool isPointerWideType(mlir::Type type);
4949

50+
/// Determines whether \p type is a valid integer type for SwitchOp.
51+
bool isSwitchOperandType(Type type);
52+
5053
} // namespace emitc
5154
} // namespace mlir
5255

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def CExpression : NativeOpTrait<"emitc::CExpression">;
5454
def IntegerIndexOrOpaqueType : Type<CPred<"emitc::isIntegerIndexOrOpaqueType($_self)">,
5555
"integer, index or opaque type supported by EmitC">;
5656
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
57+
def SwitchOperandType : Type<CPred<"emitc::isSwitchOperandType($_self)">,
58+
"integer type for switch operation">;
5759

5860
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
5961
let summary = "Addition operation";
@@ -1188,7 +1190,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
11881190
}
11891191

11901192
def EmitC_YieldOp : EmitC_Op<"yield",
1191-
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
1193+
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
11921194
let summary = "Block termination operation";
11931195
let description = [{
11941196
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
@@ -1302,5 +1304,90 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
13021304
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
13031305
}
13041306

1307+
def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
1308+
SingleBlockImplicitTerminator<"emitc::YieldOp">,
1309+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1310+
["getRegionInvocationBounds",
1311+
"getEntrySuccessorRegions"]>]> {
1312+
let summary = "Switch operation";
1313+
let description = [{
1314+
The `emitc.switch` is a control-flow operation that branches to one of
1315+
the given regions based on the values of the argument and the cases. The
1316+
argument is always of type integer (singed or unsigned), excluding i8 and i1.
1317+
If the type is not specified, then i32 will be used by default.
1318+
1319+
The operation always has a "default" region and any number of case regions
1320+
denoted by integer constants. Control-flow transfers to the case region
1321+
whose constant value equals the value of the argument. If the argument does
1322+
not equal any of the case values, control-flow transfer to the "default"
1323+
region.
1324+
1325+
The operation does not return any value. Moreover, case regions and
1326+
default region must be terminated using the `emitc.yield` operation.
1327+
1328+
Example:
1329+
1330+
```mlir
1331+
// Cases with i32 type.
1332+
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
1333+
emitc.switch %0 : i32
1334+
case 2: {
1335+
%1 = emitc.call_opaque "func_b" () : () -> i32
1336+
emitc.yield
1337+
}
1338+
case 5: {
1339+
%2 = emitc.call_opaque "func_a" () : () -> i32
1340+
emitc.yield
1341+
}
1342+
default : {
1343+
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1344+
%4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1345+
1346+
emitc.call_opaque "func2" (%3) : (f32) -> ()
1347+
emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
1348+
emitc.yield
1349+
}
1350+
...
1351+
// Cases with i16 type.
1352+
%0 = "emitc.variable"(){value = 42 : i16} : () -> i16
1353+
emitc.switch %0 : i16
1354+
case 2: {
1355+
%1 = emitc.call_opaque "func_b" () : () -> i32
1356+
emitc.yield
1357+
}
1358+
case 5: {
1359+
%2 = emitc.call_opaque "func_a" () : () -> i32
1360+
emitc.yield
1361+
}
1362+
default : {
1363+
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1364+
%4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1365+
1366+
emitc.call_opaque "func2" (%3) : (f32) -> ()
1367+
emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
1368+
emitc.yield
1369+
}
1370+
```
1371+
}];
1372+
1373+
let arguments = (ins SwitchOperandType:$arg, DenseI64ArrayAttr:$cases);
1374+
let results = (outs);
1375+
let regions = (region SizedRegion<1>:$defaultRegion,
1376+
VariadicRegion<SizedRegion<1>>:$caseRegions);
1377+
1378+
let extraClassDeclaration = [{
1379+
/// Get the number of cases.
1380+
unsigned getNumCases();
1381+
1382+
/// Get the default region body.
1383+
Block &getDefaultBlock();
1384+
1385+
/// Get the body of a case region.
1386+
Block &getCaseBlock(unsigned idx);
1387+
}];
1388+
1389+
let hasCustomAssemblyFormat = 1;
1390+
let hasVerifier = 1;
1391+
}
13051392

13061393
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
131131
type);
132132
}
133133

134+
bool mlir::emitc::isSwitchOperandType(Type type) {
135+
auto intType = llvm::dyn_cast<IntegerType>(type);
136+
return isSupportedIntegerType(type) && intType.getWidth() != 1 &&
137+
intType.getWidth() != 8;
138+
}
139+
134140
/// Check that the type of the initial value is compatible with the operations
135141
/// result type.
136142
static LogicalResult verifyInitializationAttribute(Operation *op,
@@ -1096,6 +1102,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10961102
return success();
10971103
}
10981104

1105+
//===----------------------------------------------------------------------===//
1106+
// SwitchOp
1107+
//===----------------------------------------------------------------------===//
1108+
1109+
/// Parse the case regions and values.
1110+
static ParseResult
1111+
parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
1112+
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1113+
SmallVector<int64_t> caseValues;
1114+
while (succeeded(parser.parseOptionalKeyword("case"))) {
1115+
int64_t value;
1116+
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1117+
1118+
if (parser.parseInteger(value) || parser.parseColon() ||
1119+
parser.parseRegion(region, /*arguments=*/{}))
1120+
return failure();
1121+
caseValues.push_back(value);
1122+
}
1123+
cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1124+
return success();
1125+
}
1126+
1127+
/// Print the case regions and values.
1128+
static void printSwitchCases(OpAsmPrinter &parser, Operation *op,
1129+
DenseI64ArrayAttr cases, RegionRange caseRegions) {
1130+
for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1131+
parser.printNewline();
1132+
parser << "case " << value << ": ";
1133+
parser.printRegion(*region, /*printEntryBlockArgs=*/false);
1134+
}
1135+
return;
1136+
}
1137+
1138+
ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
1139+
OpAsmParser::UnresolvedOperand arg;
1140+
DenseI64ArrayAttr casesAttr;
1141+
SmallVector<std::unique_ptr<Region>, 2> caseRegionsRegions;
1142+
std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();
1143+
1144+
if (parser.parseOperand(arg))
1145+
return failure();
1146+
1147+
Type argType = parser.getBuilder().getI32Type();
1148+
// Parse optional type, else assume i32.
1149+
if (!parser.parseOptionalColon() && parser.parseType(argType))
1150+
return failure();
1151+
1152+
auto loc = parser.getCurrentLocation();
1153+
if (parser.parseOptionalAttrDict(result.attributes))
1154+
return failure();
1155+
1156+
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
1157+
return parser.emitError(loc)
1158+
<< "'" << result.name.getStringRef() << "' op ";
1159+
})))
1160+
return failure();
1161+
1162+
auto odsResult = parseSwitchCases(parser, casesAttr, caseRegionsRegions);
1163+
if (odsResult)
1164+
return failure();
1165+
1166+
result.getOrAddProperties<SwitchOp::Properties>().cases = casesAttr;
1167+
1168+
if (parser.parseKeyword("default") || parser.parseColon())
1169+
return failure();
1170+
1171+
if (parser.parseRegion(*defaultRegionRegion))
1172+
return failure();
1173+
1174+
result.addRegion(std::move(defaultRegionRegion));
1175+
result.addRegions(caseRegionsRegions);
1176+
1177+
if (parser.resolveOperand(arg, argType, result.operands))
1178+
return failure();
1179+
1180+
return success();
1181+
}
1182+
1183+
void SwitchOp::print(OpAsmPrinter &parser) {
1184+
parser << ' ';
1185+
parser << getArg();
1186+
SmallVector<StringRef, 2> elidedAttrs;
1187+
elidedAttrs.push_back("cases");
1188+
parser.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
1189+
parser << ' ';
1190+
printSwitchCases(parser, *this, getCasesAttr(), getCaseRegions());
1191+
parser.printNewline();
1192+
parser << "default";
1193+
parser << ' ';
1194+
parser.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/true,
1195+
/*printBlockTerminators=*/true);
1196+
1197+
return;
1198+
}
1199+
1200+
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1201+
const Twine &name) {
1202+
auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1203+
if (!yield)
1204+
return op.emitOpError("expected region to end with emitc.yield, but got ")
1205+
<< region.front().back().getName();
1206+
1207+
if (yield.getNumOperands() != 0) {
1208+
return (op.emitOpError("expected each region to return ")
1209+
<< "0 values, but " << name << " returns "
1210+
<< yield.getNumOperands())
1211+
.attachNote(yield.getLoc())
1212+
<< "see yield operation here";
1213+
}
1214+
return success();
1215+
}
1216+
1217+
LogicalResult emitc::SwitchOp::verify() {
1218+
if (!isSwitchOperandType(getArg().getType()))
1219+
return emitOpError("unsupported type ") << getArg().getType();
1220+
1221+
if (getCases().size() != getCaseRegions().size()) {
1222+
return emitOpError("has ")
1223+
<< getCaseRegions().size() << " case regions but "
1224+
<< getCases().size() << " case values";
1225+
}
1226+
1227+
DenseSet<int64_t> valueSet;
1228+
for (int64_t value : getCases())
1229+
if (!valueSet.insert(value).second)
1230+
return emitOpError("has duplicate case value: ") << value;
1231+
1232+
if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1233+
return failure();
1234+
1235+
for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1236+
if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1237+
return failure();
1238+
1239+
return success();
1240+
}
1241+
1242+
unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1243+
1244+
Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1245+
1246+
Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1247+
assert(idx < getNumCases() && "case index out-of-bounds");
1248+
return getCaseRegions()[idx].front();
1249+
}
1250+
1251+
void SwitchOp::getSuccessorRegions(
1252+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1253+
llvm::copy(getRegions(), std::back_inserter(successors));
1254+
return;
1255+
}
1256+
1257+
void SwitchOp::getEntrySuccessorRegions(
1258+
ArrayRef<Attribute> operands,
1259+
SmallVectorImpl<RegionSuccessor> &successors) {
1260+
FoldAdaptor adaptor(operands, *this);
1261+
1262+
// If a constant was not provided, all regions are possible successors.
1263+
auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1264+
if (!arg) {
1265+
llvm::copy(getRegions(), std::back_inserter(successors));
1266+
return;
1267+
}
1268+
1269+
// Otherwise, try to find a case with a matching value. If not, the
1270+
// default region is the only successor.
1271+
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1272+
if (caseValue == arg.getInt()) {
1273+
successors.emplace_back(&caseRegion);
1274+
return;
1275+
}
1276+
}
1277+
successors.emplace_back(&getDefaultRegion());
1278+
return;
1279+
}
1280+
1281+
void SwitchOp::getRegionInvocationBounds(
1282+
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1283+
auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1284+
if (!operandValue) {
1285+
// All regions are invoked at most once.
1286+
bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1287+
return;
1288+
}
1289+
1290+
unsigned liveIndex = getNumRegions() - 1;
1291+
const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1292+
1293+
liveIndex = iteratorToInt != getCases().end()
1294+
? std::distance(getCases().begin(), iteratorToInt)
1295+
: liveIndex;
1296+
1297+
for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1298+
++regIndex)
1299+
bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1300+
1301+
return;
1302+
}
1303+
10991304
//===----------------------------------------------------------------------===//
11001305
// TableGen'd op method definitions
11011306
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)