@@ -131,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) {
131
131
type);
132
132
}
133
133
134
+ bool mlir::emitc::isSwitchOperandType (Type type) {
135
+ auto intType = llvm::dyn_cast<IntegerType>(type);
136
+ return isSupportedIntegerType (type) && intType.getWidth () != 1 &&
137
+ intType.getWidth () != 8 ;
138
+ }
139
+
134
140
// / Check that the type of the initial value is compatible with the operations
135
141
// / result type.
136
142
static LogicalResult verifyInitializationAttribute (Operation *op,
@@ -1096,6 +1102,205 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1096
1102
return success ();
1097
1103
}
1098
1104
1105
+ // ===----------------------------------------------------------------------===//
1106
+ // SwitchOp
1107
+ // ===----------------------------------------------------------------------===//
1108
+
1109
+ // / Parse the case regions and values.
1110
+ static ParseResult
1111
+ parseSwitchCases (OpAsmParser &parser, DenseI64ArrayAttr &cases,
1112
+ SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1113
+ SmallVector<int64_t > caseValues;
1114
+ while (succeeded (parser.parseOptionalKeyword (" case" ))) {
1115
+ int64_t value;
1116
+ Region ®ion = *caseRegions.emplace_back (std::make_unique<Region>());
1117
+
1118
+ if (parser.parseInteger (value) || parser.parseColon () ||
1119
+ parser.parseRegion (region, /* arguments=*/ {}))
1120
+ return failure ();
1121
+ caseValues.push_back (value);
1122
+ }
1123
+ cases = parser.getBuilder ().getDenseI64ArrayAttr (caseValues);
1124
+ return success ();
1125
+ }
1126
+
1127
+ // / Print the case regions and values.
1128
+ static void printSwitchCases (OpAsmPrinter &parser, Operation *op,
1129
+ DenseI64ArrayAttr cases, RegionRange caseRegions) {
1130
+ for (auto [value, region] : llvm::zip (cases.asArrayRef (), caseRegions)) {
1131
+ parser.printNewline ();
1132
+ parser << " case " << value << " : " ;
1133
+ parser.printRegion (*region, /* printEntryBlockArgs=*/ false );
1134
+ }
1135
+ return ;
1136
+ }
1137
+
1138
+ ParseResult SwitchOp::parse (OpAsmParser &parser, OperationState &result) {
1139
+ OpAsmParser::UnresolvedOperand arg;
1140
+ DenseI64ArrayAttr casesAttr;
1141
+ SmallVector<std::unique_ptr<Region>, 2 > caseRegionsRegions;
1142
+ std::unique_ptr<Region> defaultRegionRegion = std::make_unique<Region>();
1143
+
1144
+ if (parser.parseOperand (arg))
1145
+ return failure ();
1146
+
1147
+ Type argType = parser.getBuilder ().getI32Type ();
1148
+ // Parse optional type, else assume i32.
1149
+ if (!parser.parseOptionalColon () && parser.parseType (argType))
1150
+ return failure ();
1151
+
1152
+ auto loc = parser.getCurrentLocation ();
1153
+ if (parser.parseOptionalAttrDict (result.attributes ))
1154
+ return failure ();
1155
+
1156
+ if (failed (verifyInherentAttrs (result.name , result.attributes , [&]() {
1157
+ return parser.emitError (loc)
1158
+ << " '" << result.name .getStringRef () << " ' op " ;
1159
+ })))
1160
+ return failure ();
1161
+
1162
+ auto odsResult = parseSwitchCases (parser, casesAttr, caseRegionsRegions);
1163
+ if (odsResult)
1164
+ return failure ();
1165
+
1166
+ result.getOrAddProperties <SwitchOp::Properties>().cases = casesAttr;
1167
+
1168
+ if (parser.parseKeyword (" default" ) || parser.parseColon ())
1169
+ return failure ();
1170
+
1171
+ if (parser.parseRegion (*defaultRegionRegion))
1172
+ return failure ();
1173
+
1174
+ result.addRegion (std::move (defaultRegionRegion));
1175
+ result.addRegions (caseRegionsRegions);
1176
+
1177
+ if (parser.resolveOperand (arg, argType, result.operands ))
1178
+ return failure ();
1179
+
1180
+ return success ();
1181
+ }
1182
+
1183
+ void SwitchOp::print (OpAsmPrinter &parser) {
1184
+ parser << ' ' ;
1185
+ parser << getArg ();
1186
+ SmallVector<StringRef, 2 > elidedAttrs;
1187
+ elidedAttrs.push_back (" cases" );
1188
+ parser.printOptionalAttrDict ((*this )->getAttrs (), elidedAttrs);
1189
+ parser << ' ' ;
1190
+ printSwitchCases (parser, *this , getCasesAttr (), getCaseRegions ());
1191
+ parser.printNewline ();
1192
+ parser << " default" ;
1193
+ parser << ' ' ;
1194
+ parser.printRegion (getDefaultRegion (), /* printEntryBlockArgs=*/ true ,
1195
+ /* printBlockTerminators=*/ true );
1196
+
1197
+ return ;
1198
+ }
1199
+
1200
+ static LogicalResult verifyRegion (emitc::SwitchOp op, Region ®ion,
1201
+ const Twine &name) {
1202
+ auto yield = dyn_cast<emitc::YieldOp>(region.front ().back ());
1203
+ if (!yield)
1204
+ return op.emitOpError (" expected region to end with emitc.yield, but got " )
1205
+ << region.front ().back ().getName ();
1206
+
1207
+ if (yield.getNumOperands () != 0 ) {
1208
+ return (op.emitOpError (" expected each region to return " )
1209
+ << " 0 values, but " << name << " returns "
1210
+ << yield.getNumOperands ())
1211
+ .attachNote (yield.getLoc ())
1212
+ << " see yield operation here" ;
1213
+ }
1214
+ return success ();
1215
+ }
1216
+
1217
+ LogicalResult emitc::SwitchOp::verify () {
1218
+ if (!isSwitchOperandType (getArg ().getType ()))
1219
+ return emitOpError (" unsupported type " ) << getArg ().getType ();
1220
+
1221
+ if (getCases ().size () != getCaseRegions ().size ()) {
1222
+ return emitOpError (" has " )
1223
+ << getCaseRegions ().size () << " case regions but "
1224
+ << getCases ().size () << " case values" ;
1225
+ }
1226
+
1227
+ DenseSet<int64_t > valueSet;
1228
+ for (int64_t value : getCases ())
1229
+ if (!valueSet.insert (value).second )
1230
+ return emitOpError (" has duplicate case value: " ) << value;
1231
+
1232
+ if (failed (verifyRegion (*this , getDefaultRegion (), " default region" )))
1233
+ return failure ();
1234
+
1235
+ for (auto [idx, caseRegion] : llvm::enumerate (getCaseRegions ()))
1236
+ if (failed (verifyRegion (*this , caseRegion, " case region #" + Twine (idx))))
1237
+ return failure ();
1238
+
1239
+ return success ();
1240
+ }
1241
+
1242
+ unsigned emitc::SwitchOp::getNumCases () { return getCases ().size (); }
1243
+
1244
+ Block &emitc::SwitchOp::getDefaultBlock () { return getDefaultRegion ().front (); }
1245
+
1246
+ Block &emitc::SwitchOp::getCaseBlock (unsigned idx) {
1247
+ assert (idx < getNumCases () && " case index out-of-bounds" );
1248
+ return getCaseRegions ()[idx].front ();
1249
+ }
1250
+
1251
+ void SwitchOp::getSuccessorRegions (
1252
+ RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
1253
+ llvm::copy (getRegions (), std::back_inserter (successors));
1254
+ return ;
1255
+ }
1256
+
1257
+ void SwitchOp::getEntrySuccessorRegions (
1258
+ ArrayRef<Attribute> operands,
1259
+ SmallVectorImpl<RegionSuccessor> &successors) {
1260
+ FoldAdaptor adaptor (operands, *this );
1261
+
1262
+ // If a constant was not provided, all regions are possible successors.
1263
+ auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg ());
1264
+ if (!arg) {
1265
+ llvm::copy (getRegions (), std::back_inserter (successors));
1266
+ return ;
1267
+ }
1268
+
1269
+ // Otherwise, try to find a case with a matching value. If not, the
1270
+ // default region is the only successor.
1271
+ for (auto [caseValue, caseRegion] : llvm::zip (getCases (), getCaseRegions ())) {
1272
+ if (caseValue == arg.getInt ()) {
1273
+ successors.emplace_back (&caseRegion);
1274
+ return ;
1275
+ }
1276
+ }
1277
+ successors.emplace_back (&getDefaultRegion ());
1278
+ return ;
1279
+ }
1280
+
1281
+ void SwitchOp::getRegionInvocationBounds (
1282
+ ArrayRef<Attribute> operands, SmallVectorImpl<InvocationBounds> &bounds) {
1283
+ auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front ());
1284
+ if (!operandValue) {
1285
+ // All regions are invoked at most once.
1286
+ bounds.append (getNumRegions (), InvocationBounds (/* lb=*/ 0 , /* ub=*/ 1 ));
1287
+ return ;
1288
+ }
1289
+
1290
+ unsigned liveIndex = getNumRegions () - 1 ;
1291
+ const auto *iteratorToInt = llvm::find (getCases (), operandValue.getInt ());
1292
+
1293
+ liveIndex = iteratorToInt != getCases ().end ()
1294
+ ? std::distance (getCases ().begin (), iteratorToInt)
1295
+ : liveIndex;
1296
+
1297
+ for (unsigned regIndex = 0 , regNum = getNumRegions (); regIndex < regNum;
1298
+ ++regIndex)
1299
+ bounds.emplace_back (/* lb=*/ 0 , /* ub=*/ regIndex == liveIndex);
1300
+
1301
+ return ;
1302
+ }
1303
+
1099
1304
// ===----------------------------------------------------------------------===//
1100
1305
// TableGen'd op method definitions
1101
1306
// ===----------------------------------------------------------------------===//
0 commit comments