|
14 | 14 |
|
15 | 15 | #include "mlir/TableGen/CodeGenHelpers.h"
|
16 | 16 | #include "llvm/ADT/StringExtras.h"
|
| 17 | +#include "llvm/ADT/StringSet.h" |
17 | 18 | #include "llvm/ADT/TypeSwitch.h"
|
18 | 19 | #include "llvm/TableGen/Error.h"
|
19 | 20 | #include "llvm/TableGen/Record.h"
|
@@ -184,40 +185,59 @@ static void verifyClause(Record *op, Record *clause) {
|
184 | 185 | /// All kinds of values are represented as `mlir::Value` fields, whereas
|
185 | 186 | /// attributes are represented based on their `storageType`.
|
186 | 187 | ///
|
| 188 | +/// \param[in] name The name of the argument. |
187 | 189 | /// \param[in] init The `DefInit` object representing the argument.
|
188 | 190 | /// \param[out] rank Number of levels of array nesting associated with the
|
189 | 191 | /// type.
|
190 | 192 | ///
|
191 | 193 | /// \return the name of the base type to represent elements of the argument
|
192 | 194 | /// type.
|
193 |
| -static StringRef translateArgumentType(Init *init, int &rank) { |
| 195 | +static StringRef translateArgumentType(ArrayRef<SMLoc> loc, StringInit *name, |
| 196 | + Init *init, int &rank) { |
194 | 197 | Record *def = cast<DefInit>(init)->getDef();
|
195 |
| - bool isAttr = false, isValue = false; |
196 | 198 |
|
197 |
| - for (auto [sc, _] : def->getSuperClasses()) { |
198 |
| - std::string scName = sc->getNameInitAsString(); |
199 |
| - if (scName == "OptionalAttr") |
200 |
| - return translateArgumentType(def->getValue("baseAttr")->getValue(), rank); |
201 |
| - |
202 |
| - if (scName == "TypedArrayAttrBase") { |
203 |
| - ++rank; |
204 |
| - return translateArgumentType(def->getValue("elementAttr")->getValue(), |
205 |
| - rank); |
206 |
| - } |
207 |
| - |
208 |
| - if (scName == "ElementsAttrBase") { |
209 |
| - rank += def->getValueAsInt("rank"); |
210 |
| - return def->getValueAsString("elementReturnType").trim(); |
211 |
| - } |
212 |
| - |
213 |
| - if (scName == "Attr") |
214 |
| - isAttr = true; |
215 |
| - else if (scName == "TypeConstraint") |
216 |
| - isValue = true; |
217 |
| - else if (scName == "Variadic") |
218 |
| - ++rank; |
| 199 | + llvm::StringSet superClasses; |
| 200 | + for (auto [sc, _] : def->getSuperClasses()) |
| 201 | + superClasses.insert(sc->getNameInitAsString()); |
| 202 | + |
| 203 | + // Handle wrapper-style superclasses. |
| 204 | + if (superClasses.contains("OptionalAttr")) |
| 205 | + return translateArgumentType(loc, name, |
| 206 | + def->getValue("baseAttr")->getValue(), rank); |
| 207 | + |
| 208 | + if (superClasses.contains("TypedArrayAttrBase")) |
| 209 | + return translateArgumentType( |
| 210 | + loc, name, def->getValue("elementAttr")->getValue(), ++rank); |
| 211 | + |
| 212 | + // Handle ElementsAttrBase superclasses. |
| 213 | + if (superClasses.contains("ElementsAttrBase")) { |
| 214 | + // TODO: Support properly obtaining rank from ranked types. |
| 215 | + ++rank; |
| 216 | + |
| 217 | + if (superClasses.contains("IntElementsAttrBase")) |
| 218 | + return "::llvm::APInt"; |
| 219 | + if (superClasses.contains("FloatElementsAttr") || |
| 220 | + superClasses.contains("RankedFloatElementsAttr")) |
| 221 | + return "::llvm::APFloat"; |
| 222 | + if (superClasses.contains("DenseArrayAttrBase")) |
| 223 | + return stripPrefixAndSuffix(def->getValueAsString("returnType"), |
| 224 | + {"::llvm::ArrayRef<"}, {">"}); |
| 225 | + |
| 226 | + // Reset the rank in the case where the base type cannot be inferred, so |
| 227 | + // that the bare storageType is used instead of a vector. |
| 228 | + rank = 0; |
| 229 | + PrintWarning( |
| 230 | + loc, |
| 231 | + "could not infer array-like attribute element type for argument '" + |
| 232 | + name->getAsUnquotedString() + "', will use bare `storageType`"); |
219 | 233 | }
|
220 | 234 |
|
| 235 | + // Handle simple attribute and value types. |
| 236 | + bool isAttr = superClasses.contains("Attr"); |
| 237 | + bool isValue = superClasses.contains("TypeConstraint"); |
| 238 | + if (superClasses.contains("Variadic")) |
| 239 | + ++rank; |
| 240 | + |
221 | 241 | if (isValue) {
|
222 | 242 | assert(!isAttr &&
|
223 | 243 | "argument can't be simultaneously a value and an attribute");
|
@@ -246,7 +266,8 @@ static void genClauseOpsStruct(Record *clause, raw_ostream &os) {
|
246 | 266 | for (auto [name, arg] :
|
247 | 267 | zip_equal(arguments->getArgNames(), arguments->getArgs())) {
|
248 | 268 | int rank = 0;
|
249 |
| - StringRef baseType = translateArgumentType(arg, rank); |
| 269 | + StringRef baseType = |
| 270 | + translateArgumentType(clause->getLoc(), name, arg, rank); |
250 | 271 |
|
251 | 272 | if (rank > 0)
|
252 | 273 | os << " ::llvm::SmallVector<" << baseType << ">";
|
|
0 commit comments