Skip to content

Commit 11a57c3

Browse files
committed
[mlir][emitc] Add 'emitc.switch' op to the dialect
1 parent ad8a2e4 commit 11a57c3

File tree

6 files changed

+641
-4
lines changed

6 files changed

+641
-4
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type);
4747
/// Determines whether \p type is a emitc.size_t/ssize_t type.
4848
bool isPointerWideType(mlir::Type type);
4949

50+
/// Determines whether \p type is a valid integer or opaque type for SwitchOp.
51+
bool isSwitchOperandType(Type type);
52+
5053
} // namespace emitc
5154
} // namespace mlir
5255

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ def CExpression : NativeOpTrait<"emitc::CExpression">;
5454
def IntegerIndexOrOpaqueType : Type<CPred<"emitc::isIntegerIndexOrOpaqueType($_self)">,
5555
"integer, index or opaque type supported by EmitC">;
5656
def FloatIntegerIndexOrOpaqueType : AnyTypeOf<[EmitCFloatType, IntegerIndexOrOpaqueType]>;
57+
def SwitchOperandType : Type<CPred<"emitc::isSwitchOperandType($_self)">,
58+
"integer type for switch operation">;
5759

5860
def EmitC_AddOp : EmitC_BinaryOp<"add", [CExpression]> {
5961
let summary = "Addition operation";
@@ -1188,7 +1190,7 @@ def EmitC_AssignOp : EmitC_Op<"assign", []> {
11881190
}
11891191

11901192
def EmitC_YieldOp : EmitC_Op<"yield",
1191-
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp"]>]> {
1193+
[Pure, Terminator, ParentOneOf<["ExpressionOp", "IfOp", "ForOp", "SwitchOp"]>]> {
11921194
let summary = "Block termination operation";
11931195
let description = [{
11941196
The `emitc.yield` terminates its parent EmitC op's region, optionally yielding
@@ -1302,5 +1304,90 @@ def EmitC_SubscriptOp : EmitC_Op<"subscript", []> {
13021304
let assemblyFormat = "$value `[` $indices `]` attr-dict `:` functional-type(operands, results)";
13031305
}
13041306

1307+
def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
1308+
SingleBlockImplicitTerminator<"emitc::YieldOp">,
1309+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1310+
["getRegionInvocationBounds",
1311+
"getEntrySuccessorRegions"]>]> {
1312+
let summary = "Switch operation";
1313+
let description = [{
1314+
The `emitc.switch` is a control-flow operation that branches to one of
1315+
the given regions based on the values of the argument and the cases. The
1316+
argument is always of type opaque or integer (singed or unsigned), excluding i8
1317+
and i1.
1318+
1319+
The operation always has a "default" region and any number of case regions
1320+
denoted by integer constants. Control-flow transfers to the case region
1321+
whose constant value equals the value of the argument. If the argument does
1322+
not equal any of the case values, control-flow transfer to the "default"
1323+
region.
1324+
1325+
The operation does not return any value. Moreover, case regions and
1326+
default region must be terminated using the `emitc.yield` operation.
1327+
1328+
Example:
1329+
1330+
```mlir
1331+
// Cases with i32 type.
1332+
%0 = "emitc.variable"(){value = 42 : i32} : () -> i32
1333+
emitc.switch %0 : i32
1334+
case 2: {
1335+
%1 = emitc.call_opaque "func_b" () : () -> i32
1336+
emitc.yield
1337+
}
1338+
case 5: {
1339+
%2 = emitc.call_opaque "func_a" () : () -> i32
1340+
emitc.yield
1341+
}
1342+
default : {
1343+
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1344+
%4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1345+
1346+
emitc.call_opaque "func2" (%3) : (f32) -> ()
1347+
emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
1348+
emitc.yield
1349+
}
1350+
...
1351+
// Cases with i16 type.
1352+
%0 = "emitc.variable"(){value = 42 : i16} : () -> i16
1353+
emitc.switch %0 : i16
1354+
case 2: {
1355+
%1 = emitc.call_opaque "func_b" () : () -> i32
1356+
emitc.yield
1357+
}
1358+
case 5: {
1359+
%2 = emitc.call_opaque "func_a" () : () -> i32
1360+
emitc.yield
1361+
}
1362+
default : {
1363+
%3 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1364+
%4 = "emitc.variable"(){value = 42.0 : f32} : () -> f32
1365+
1366+
emitc.call_opaque "func2" (%3) : (f32) -> ()
1367+
emitc.call_opaque "func3" (%3, %4) { args = [1 : index, 0 : index] } : (f32, f32) -> ()
1368+
emitc.yield
1369+
}
1370+
```
1371+
}];
1372+
1373+
let arguments = (ins SwitchOperandType:$arg, DenseI64ArrayAttr:$cases);
1374+
let results = (outs);
1375+
let regions = (region SizedRegion<1>:$defaultRegion,
1376+
VariadicRegion<SizedRegion<1>>:$caseRegions);
1377+
1378+
let extraClassDeclaration = [{
1379+
/// Get the number of cases.
1380+
unsigned getNumCases();
1381+
1382+
/// Get the default region body.
1383+
Block &getDefaultBlock();
1384+
1385+
/// Get the body of a case region.
1386+
Block &getCaseBlock(unsigned idx);
1387+
}];
1388+
1389+
let hasCustomAssemblyFormat = 1;
1390+
let hasVerifier = 1;
1391+
}
13051392

13061393
#endif // MLIR_DIALECT_EMITC_IR_EMITC

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ bool mlir::emitc::isPointerWideType(Type type) {
131131
type);
132132
}
133133

134+
bool mlir::emitc::isSwitchOperandType(Type type) {
135+
auto intType = llvm::dyn_cast<IntegerType>(type);
136+
return isa<emitc::OpaqueType>(type) ||
137+
(isSupportedIntegerType(type) && intType.getWidth() != 1 &&
138+
intType.getWidth() != 8);
139+
}
140+
134141
/// Check that the type of the initial value is compatible with the operations
135142
/// result type.
136143
static LogicalResult verifyInitializationAttribute(Operation *op,
@@ -1096,6 +1103,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
10961103
return success();
10971104
}
10981105

1106+
//===----------------------------------------------------------------------===//
1107+
// SwitchOp
1108+
//===----------------------------------------------------------------------===//
1109+
1110+
/// Parse the case regions and values.
1111+
static ParseResult
1112+
parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases,
1113+
SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1114+
SmallVector<int64_t> caseValues;
1115+
while (succeeded(parser.parseOptionalKeyword("case"))) {
1116+
int64_t value;
1117+
Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1118+
1119+
if (parser.parseInteger(value) || parser.parseColon() ||
1120+
parser.parseRegion(region, /*arguments=*/{}))
1121+
return failure();
1122+
caseValues.push_back(value);
1123+
}
1124+
cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1125+
return success();
1126+
}
1127+
1128+
/// Print the case regions and values.
1129+
static void printSwitchCases(OpAsmPrinter &parser, Operation *op,
1130+
DenseI64ArrayAttr cases, RegionRange caseRegions) {
1131+
for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1132+
parser.printNewline();
1133+
parser << "case " << value << ": ";
1134+
parser.printRegion(*region, /*printEntryBlockArgs=*/false);
1135+
}
1136+
return;
1137+
}
1138+
1139+
ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) {
1140+
OpAsmParser::UnresolvedOperand arg;
1141+
DenseI64ArrayAttr casesAttr;
1142+
SmallVector<std::unique_ptr<Region>, 2> caseRegionsRegions;
1143+
std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();
1144+
1145+
if (parser.parseOperand(arg))
1146+
return failure();
1147+
1148+
Type argType;
1149+
// Parse the case's type.
1150+
if (parser.parseColon() || parser.parseType(argType))
1151+
return failure();
1152+
1153+
auto loc = parser.getCurrentLocation();
1154+
if (parser.parseOptionalAttrDict(result.attributes))
1155+
return failure();
1156+
1157+
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
1158+
return parser.emitError(loc)
1159+
<< "'" << result.name.getStringRef() << "' op ";
1160+
})))
1161+
return failure();
1162+
1163+
auto odsResult = parseSwitchCases(parser, casesAttr, caseRegionsRegions);
1164+
if (odsResult)
1165+
return failure();
1166+
1167+
result.getOrAddProperties<SwitchOp::Properties>().cases = casesAttr;
1168+
1169+
if (parser.parseKeyword("default") || parser.parseColon())
1170+
return failure();
1171+
1172+
if (parser.parseRegion(*defaultRegionRegion))
1173+
return failure();
1174+
1175+
result.addRegion(std::move(defaultRegionRegion));
1176+
result.addRegions(caseRegionsRegions);
1177+
1178+
if (parser.resolveOperand(arg, argType, result.operands))
1179+
return failure();
1180+
1181+
return success();
1182+
}
1183+
1184+
void SwitchOp::print(OpAsmPrinter &parser) {
1185+
parser << ' ';
1186+
parser << getArg();
1187+
SmallVector<StringRef, 2> elidedAttrs;
1188+
elidedAttrs.push_back("cases");
1189+
parser.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
1190+
parser << ' ';
1191+
printSwitchCases(parser, *this, getCasesAttr(), getCaseRegions());
1192+
parser.printNewline();
1193+
parser << "default";
1194+
parser << ' ';
1195+
parser.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/true,
1196+
/*printBlockTerminators=*/true);
1197+
1198+
return;
1199+
}
1200+
1201+
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1202+
const Twine &name) {
1203+
auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1204+
if (!yield)
1205+
return op.emitOpError("expected region to end with emitc.yield, but got ")
1206+
<< region.front().back().getName();
1207+
1208+
if (yield.getNumOperands() != 0) {
1209+
return (op.emitOpError("expected each region to return ")
1210+
<< "0 values, but " << name << " returns "
1211+
<< yield.getNumOperands())
1212+
.attachNote(yield.getLoc())
1213+
<< "see yield operation here";
1214+
}
1215+
return success();
1216+
}
1217+
1218+
LogicalResult emitc::SwitchOp::verify() {
1219+
if (!isSwitchOperandType(getArg().getType()))
1220+
return emitOpError("unsupported type ") << getArg().getType();
1221+
1222+
if (getCases().size() != getCaseRegions().size()) {
1223+
return emitOpError("has ")
1224+
<< getCaseRegions().size() << " case regions but "
1225+
<< getCases().size() << " case values";
1226+
}
1227+
1228+
DenseSet<int64_t> valueSet;
1229+
for (int64_t value : getCases())
1230+
if (!valueSet.insert(value).second)
1231+
return emitOpError("has duplicate case value: ") << value;
1232+
1233+
if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1234+
return failure();
1235+
1236+
for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1237+
if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1238+
return failure();
1239+
1240+
return success();
1241+
}
1242+
1243+
unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1244+
1245+
Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1246+
1247+
Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1248+
assert(idx < getNumCases() && "case index out-of-bounds");
1249+
return getCaseRegions()[idx].front();
1250+
}
1251+
1252+
void SwitchOp::getSuccessorRegions(
1253+
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1254+
llvm::copy(getRegions(), std::back_inserter(successors));
1255+
return;
1256+
}
1257+
1258+
void SwitchOp::getEntrySuccessorRegions(
1259+
ArrayRef<Attribute> operands,
1260+
SmallVectorImpl<RegionSuccessor> &successors) {
1261+
FoldAdaptor adaptor(operands, *this);
1262+
1263+
// If a constant was not provided, all regions are possible successors.
1264+
auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1265+
if (!arg) {
1266+
llvm::copy(getRegions(), std::back_inserter(successors));
1267+
return;
1268+
}
1269+
1270+
// Otherwise, try to find a case with a matching value. If not, the
1271+
// default region is the only successor.
1272+
for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1273+
if (caseValue == arg.getInt()) {
1274+
successors.emplace_back(&caseRegion);
1275+
return;
1276+
}
1277+
}
1278+
successors.emplace_back(&getDefaultRegion());
1279+
return;
1280+
}
1281+
1282+
void SwitchOp::getRegionInvocationBounds(
1283+
ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1284+
auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1285+
if (!operandValue) {
1286+
// All regions are invoked at most once.
1287+
bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1288+
return;
1289+
}
1290+
1291+
unsigned liveIndex = getNumRegions() - 1;
1292+
const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1293+
1294+
liveIndex = iteratorToInt != getCases().end()
1295+
? std::distance(getCases().begin(), iteratorToInt)
1296+
: liveIndex;
1297+
1298+
for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1299+
++regIndex)
1300+
bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1301+
1302+
return;
1303+
}
1304+
10991305
//===----------------------------------------------------------------------===//
11001306
// TableGen'd op method definitions
11011307
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)