@@ -131,6 +131,13 @@ bool mlir::emitc::isPointerWideType(Type type) {
131
131
type);
132
132
}
133
133
134
+ bool mlir::emitc::isSwitchOperandType (Type type) {
135
+ auto intType = llvm::dyn_cast<IntegerType>(type);
136
+ return isa<emitc::OpaqueType>(type) ||
137
+ (isSupportedIntegerType (type) && intType.getWidth () != 1 &&
138
+ intType.getWidth () != 8 );
139
+ }
140
+
134
141
// / Check that the type of the initial value is compatible with the operations
135
142
// / result type.
136
143
static LogicalResult verifyInitializationAttribute (Operation *op,
@@ -1096,6 +1103,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1096
1103
return success ();
1097
1104
}
1098
1105
1106
+ // ===----------------------------------------------------------------------===//
1107
+ // SwitchOp
1108
+ // ===----------------------------------------------------------------------===//
1109
+
1110
+ // / Parse the case regions and values.
1111
+ static ParseResult
1112
+ parseSwitchCases (OpAsmParser &parser, DenseI64ArrayAttr &cases,
1113
+ SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1114
+ SmallVector<int64_t > caseValues;
1115
+ while (succeeded (parser.parseOptionalKeyword (" case" ))) {
1116
+ int64_t value;
1117
+ Region ®ion = *caseRegions.emplace_back (std::make_unique<Region>());
1118
+
1119
+ if (parser.parseInteger (value) || parser.parseColon () ||
1120
+ parser.parseRegion (region, /* arguments=*/ {}))
1121
+ return failure ();
1122
+ caseValues.push_back (value);
1123
+ }
1124
+ cases = parser.getBuilder ().getDenseI64ArrayAttr (caseValues);
1125
+ return success ();
1126
+ }
1127
+
1128
+ // / Print the case regions and values.
1129
+ static void printSwitchCases (OpAsmPrinter &parser, Operation *op,
1130
+ DenseI64ArrayAttr cases, RegionRange caseRegions) {
1131
+ for (auto [value, region] : llvm::zip (cases.asArrayRef (), caseRegions)) {
1132
+ parser.printNewline ();
1133
+ parser << " case " << value << " : " ;
1134
+ parser.printRegion (*region, /* printEntryBlockArgs=*/ false );
1135
+ }
1136
+ return ;
1137
+ }
1138
+
1139
+ ParseResult SwitchOp::parse (OpAsmParser &parser, OperationState &result) {
1140
+ OpAsmParser::UnresolvedOperand arg;
1141
+ DenseI64ArrayAttr casesAttr;
1142
+ SmallVector<std::unique_ptr<Region>, 2 > caseRegionsRegions;
1143
+ std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();
1144
+
1145
+ if (parser.parseOperand (arg))
1146
+ return failure ();
1147
+
1148
+ Type argType;
1149
+ // Parse the case's type.
1150
+ if (parser.parseColon () || parser.parseType (argType))
1151
+ return failure ();
1152
+
1153
+ auto loc = parser.getCurrentLocation ();
1154
+ if (parser.parseOptionalAttrDict (result.attributes ))
1155
+ return failure ();
1156
+
1157
+ if (failed (verifyInherentAttrs (result.name , result.attributes , [&]() {
1158
+ return parser.emitError (loc)
1159
+ << " '" << result.name .getStringRef () << " ' op " ;
1160
+ })))
1161
+ return failure ();
1162
+
1163
+ auto odsResult = parseSwitchCases (parser, casesAttr, caseRegionsRegions);
1164
+ if (odsResult)
1165
+ return failure ();
1166
+
1167
+ result.getOrAddProperties <SwitchOp::Properties>().cases = casesAttr;
1168
+
1169
+ if (parser.parseKeyword (" default" ) || parser.parseColon ())
1170
+ return failure ();
1171
+
1172
+ if (parser.parseRegion (*defaultRegionRegion))
1173
+ return failure ();
1174
+
1175
+ result.addRegion (std::move (defaultRegionRegion));
1176
+ result.addRegions (caseRegionsRegions);
1177
+
1178
+ if (parser.resolveOperand (arg, argType, result.operands ))
1179
+ return failure ();
1180
+
1181
+ return success ();
1182
+ }
1183
+
1184
+ void SwitchOp::print (OpAsmPrinter &parser) {
1185
+ parser << ' ' ;
1186
+ parser << getArg ();
1187
+ SmallVector<StringRef, 2 > elidedAttrs;
1188
+ elidedAttrs.push_back (" cases" );
1189
+ parser.printOptionalAttrDict ((*this )->getAttrs (), elidedAttrs);
1190
+ parser << ' ' ;
1191
+ printSwitchCases (parser, *this , getCasesAttr (), getCaseRegions ());
1192
+ parser.printNewline ();
1193
+ parser << " default" ;
1194
+ parser << ' ' ;
1195
+ parser.printRegion (getDefaultRegion (), /* printEntryBlockArgs=*/ true ,
1196
+ /* printBlockTerminators=*/ true );
1197
+
1198
+ return ;
1199
+ }
1200
+
1201
+ static LogicalResult verifyRegion (emitc::SwitchOp op, Region ®ion,
1202
+ const Twine &name) {
1203
+ auto yield = dyn_cast<emitc::YieldOp>(region.front ().back ());
1204
+ if (!yield)
1205
+ return op.emitOpError (" expected region to end with emitc.yield, but got " )
1206
+ << region.front ().back ().getName ();
1207
+
1208
+ if (yield.getNumOperands () != 0 ) {
1209
+ return (op.emitOpError (" expected each region to return " )
1210
+ << " 0 values, but " << name << " returns "
1211
+ << yield.getNumOperands ())
1212
+ .attachNote (yield.getLoc ())
1213
+ << " see yield operation here" ;
1214
+ }
1215
+ return success ();
1216
+ }
1217
+
1218
+ LogicalResult emitc::SwitchOp::verify () {
1219
+ if (!isSwitchOperandType (getArg ().getType ()))
1220
+ return emitOpError (" unsupported type " ) << getArg ().getType ();
1221
+
1222
+ if (getCases ().size () != getCaseRegions ().size ()) {
1223
+ return emitOpError (" has " )
1224
+ << getCaseRegions ().size () << " case regions but "
1225
+ << getCases ().size () << " case values" ;
1226
+ }
1227
+
1228
+ DenseSet<int64_t > valueSet;
1229
+ for (int64_t value : getCases ())
1230
+ if (!valueSet.insert (value).second )
1231
+ return emitOpError (" has duplicate case value: " ) << value;
1232
+
1233
+ if (failed (verifyRegion (*this , getDefaultRegion (), " default region" )))
1234
+ return failure ();
1235
+
1236
+ for (auto [idx, caseRegion] : llvm::enumerate (getCaseRegions ()))
1237
+ if (failed (verifyRegion (*this , caseRegion, " case region #" + Twine (idx))))
1238
+ return failure ();
1239
+
1240
+ return success ();
1241
+ }
1242
+
1243
+ unsigned emitc::SwitchOp::getNumCases () { return getCases ().size (); }
1244
+
1245
+ Block &emitc::SwitchOp::getDefaultBlock () { return getDefaultRegion ().front (); }
1246
+
1247
+ Block &emitc::SwitchOp::getCaseBlock (unsigned idx) {
1248
+ assert (idx < getNumCases () && " case index out-of-bounds" );
1249
+ return getCaseRegions ()[idx].front ();
1250
+ }
1251
+
1252
+ void SwitchOp::getSuccessorRegions (
1253
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1254
+ llvm::copy (getRegions (), std::back_inserter (successors));
1255
+ return ;
1256
+ }
1257
+
1258
+ void SwitchOp::getEntrySuccessorRegions (
1259
+ ArrayRef<Attribute> operands,
1260
+ SmallVectorImpl<RegionSuccessor> &successors) {
1261
+ FoldAdaptor adaptor (operands, *this );
1262
+
1263
+ // If a constant was not provided, all regions are possible successors.
1264
+ auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg ());
1265
+ if (!arg) {
1266
+ llvm::copy (getRegions (), std::back_inserter (successors));
1267
+ return ;
1268
+ }
1269
+
1270
+ // Otherwise, try to find a case with a matching value. If not, the
1271
+ // default region is the only successor.
1272
+ for (auto [caseValue, caseRegion] : llvm::zip (getCases (), getCaseRegions ())) {
1273
+ if (caseValue == arg.getInt ()) {
1274
+ successors.emplace_back (&caseRegion);
1275
+ return ;
1276
+ }
1277
+ }
1278
+ successors.emplace_back (&getDefaultRegion ());
1279
+ return ;
1280
+ }
1281
+
1282
+ void SwitchOp::getRegionInvocationBounds (
1283
+ ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1284
+ auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front ());
1285
+ if (!operandValue) {
1286
+ // All regions are invoked at most once.
1287
+ bounds.append (getNumRegions (), InvocationBounds (/* lb=*/ 0 , /* ub=*/ 1 ));
1288
+ return ;
1289
+ }
1290
+
1291
+ unsigned liveIndex = getNumRegions () - 1 ;
1292
+ const auto *iteratorToInt = llvm::find (getCases (), operandValue.getInt ());
1293
+
1294
+ liveIndex = iteratorToInt != getCases ().end ()
1295
+ ? std::distance (getCases ().begin (), iteratorToInt)
1296
+ : liveIndex;
1297
+
1298
+ for (unsigned regIndex = 0 , regNum = getNumRegions (); regIndex < regNum;
1299
+ ++regIndex)
1300
+ bounds.emplace_back (/* lb=*/ 0 , /* ub=*/ regIndex == liveIndex);
1301
+
1302
+ return ;
1303
+ }
1304
+
1099
1305
// ===----------------------------------------------------------------------===//
1100
1306
// TableGen'd op method definitions
1101
1307
// ===----------------------------------------------------------------------===//
0 commit comments