Skip to content

Commit 56a445c

Browse files
committed
Address review comments
1 parent 8948a6b commit 56a445c

File tree

3 files changed

+58
-45
lines changed

3 files changed

+58
-45
lines changed

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -408,26 +408,17 @@ class ElementsAttrBase<Pred condition, string summary> :
408408
let storageType = [{ ::mlir::ElementsAttr }];
409409
let returnType = [{ ::mlir::ElementsAttr }];
410410
let convertFromStorage = "$_self";
411-
412-
// The underlying C++ value type of each element.
413-
string elementReturnType = ?;
414-
415-
// The number of dimensions represented by the element collection.
416-
int rank = 1;
417411
}
418412

419413
def ElementsAttr : ElementsAttrBase<CPred<"::llvm::isa<::mlir::ElementsAttr>($_self)">,
420-
"constant vector/tensor attribute"> {
421-
let elementReturnType = [{ ::mlir::Attribute }];
422-
}
414+
"constant vector/tensor attribute">;
423415

424416
class IntElementsAttrBase<Pred condition, string summary> :
425417
ElementsAttrBase<And<[CPred<"::llvm::isa<::mlir::DenseIntElementsAttr>($_self)">,
426418
condition]>,
427419
summary> {
428420
let storageType = [{ ::mlir::DenseIntElementsAttr }];
429421
let returnType = [{ ::mlir::DenseIntElementsAttr }];
430-
let elementReturnType = [{ ::llvm::APInt }];
431422

432423
let convertFromStorage = "$_self";
433424
}
@@ -437,7 +428,6 @@ class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryNam
437428
summaryName # " dense array attribute"> {
438429
let storageType = "::mlir::" # denseAttrName;
439430
let returnType = "::llvm::ArrayRef<" # cppType # ">";
440-
let elementReturnType = cppType;
441431
let constBuilderCall = "$_builder.get" # denseAttrName # "($0)";
442432
}
443433
def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
@@ -496,8 +486,6 @@ class RankedSignlessIntElementsAttr<int width, list<int> dims> :
496486
let constBuilderCall = "::mlir::DenseIntElementsAttr::get("
497487
"::mlir::RankedTensorType::get({" # !interleave(dims, ", ") #
498488
"}, $_builder.getIntegerType(" # width # ")), ::llvm::ArrayRef($0))";
499-
500-
let rank = !size(dims);
501489
}
502490

503491
class RankedI32ElementsAttr<list<int> dims> :
@@ -513,7 +501,6 @@ class FloatElementsAttr<int width> : ElementsAttrBase<
513501

514502
let storageType = [{ ::mlir::DenseElementsAttr }];
515503
let returnType = [{ ::mlir::DenseElementsAttr }];
516-
let elementReturnType = [{ ::llvm::APFloat }];
517504

518505
// Note that this is only constructing scalar elements attribute.
519506
let constBuilderCall = "::mlir::DenseElementsAttr::get("
@@ -539,8 +526,6 @@ class RankedFloatElementsAttr<int width, list<int> dims> : ElementsAttrBase<
539526

540527
let storageType = [{ ::mlir::DenseFPElementsAttr }];
541528
let returnType = [{ ::mlir::DenseFPElementsAttr }];
542-
let elementReturnType = [{ ::llvm::APFloat }];
543-
let rank = !size(dims);
544529

545530
let constBuilderCall = "::llvm::cast<::mlir::DenseFPElementsAttr>("
546531
"::mlir::DenseElementsAttr::get("
@@ -559,7 +544,6 @@ def StringElementsAttr : ElementsAttrBase<
559544

560545
let storageType = [{ ::mlir::DenseElementsAttr }];
561546
let returnType = [{ ::mlir::DenseElementsAttr }];
562-
let elementReturnType = [{ ::llvm::SmallString }];
563547

564548
let convertFromStorage = "$_self";
565549
}

mlir/test/mlir-tblgen/openmp-clause-ops.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
// RUN: %S/../../../llvm/include/llvm/Frontend/OpenMP/OMP.td \
88
// RUN: -I %S/../../../llvm/include > %t/mlir/Dialect/OpenMP/OmpCommon.td
99

10-
// RUN: mlir-tblgen -gen-openmp-clause-ops -I %S/../../include -I %t %s | FileCheck %s
10+
// RUN: mlir-tblgen -gen-openmp-clause-ops -I %S/../../include -I %t %s 2>&1 | FileCheck %s
1111

1212
include "mlir/Dialect/OpenMP/OpenMPOpBase.td"
1313

@@ -34,7 +34,13 @@ def OpenMP_MyFirstClause : OpenMP_Clause<
3434
OptionalAttr<DenseI8ArrayAttr>:$opt_int_elems_attr,
3535

3636
// Multi-level composition
37-
ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$complex_opt_int_attr
37+
ConfinedAttr<OptionalAttr<I64Attr>, [IntMinValue<0>]>:$complex_opt_int_attr,
38+
39+
// ElementsAttrBase-related edge cases.
40+
// CHECK: warning: could not infer array-like attribute element type for argument 'elements_attr', will use bare `storageType`
41+
ElementsAttr:$elements_attr,
42+
// CHECK: warning: could not infer array-like attribute element type for argument 'string_elements_attr', will use bare `storageType`
43+
StringElementsAttr:$string_elements_attr
3844
);
3945
}
4046
// CHECK: struct MyFirstClauseOps {
@@ -46,13 +52,15 @@ def OpenMP_MyFirstClause : OpenMP_Clause<
4652
// CHECK-NEXT: ::llvm::SmallVector<::mlir::Attribute> strArrayAttr;
4753
// CHECK-NEXT: ::llvm::SmallVector<::llvm::APInt> anyintElemsAttr;
4854
// CHECK-NEXT: ::llvm::SmallVector<::llvm::APFloat> floatNdElemsAttr;
49-
// CHECK-NEXT: int floatNdElemsAttrDims[3];
5055

5156
// CHECK-NEXT: ::mlir::BoolAttr optBoolAttr;
5257
// CHECK-NEXT: ::llvm::SmallVector<::mlir::Attribute> optIntArrayAttr;
5358
// CHECK-NEXT: ::llvm::SmallVector<int8_t> optIntElemsAttr;
5459

5560
// CHECK-NEXT: ::mlir::IntegerAttr complexOptIntAttr;
61+
62+
// CHECK-NEXT: ::mlir::ElementsAttr elementsAttr;
63+
// CHECK-NEXT: ::mlir::DenseElementsAttr stringElementsAttr;
5664
// CHECK-NEXT: }
5765

5866
def OpenMP_MySecondClause : OpenMP_Clause<

mlir/tools/mlir-tblgen/OmpOpGen.cpp

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "mlir/TableGen/CodeGenHelpers.h"
1616
#include "llvm/ADT/StringExtras.h"
17+
#include "llvm/ADT/StringSet.h"
1718
#include "llvm/ADT/TypeSwitch.h"
1819
#include "llvm/TableGen/Error.h"
1920
#include "llvm/TableGen/Record.h"
@@ -184,40 +185,59 @@ static void verifyClause(Record *op, Record *clause) {
184185
/// All kinds of values are represented as `mlir::Value` fields, whereas
185186
/// attributes are represented based on their `storageType`.
186187
///
188+
/// \param[in] name The name of the argument.
187189
/// \param[in] init The `DefInit` object representing the argument.
188190
/// \param[out] rank Number of levels of array nesting associated with the
189191
/// type.
190192
///
191193
/// \return the name of the base type to represent elements of the argument
192194
/// type.
193-
static StringRef translateArgumentType(Init *init, int &rank) {
195+
static StringRef translateArgumentType(ArrayRef<SMLoc> loc, StringInit *name,
196+
Init *init, int &rank) {
194197
Record *def = cast<DefInit>(init)->getDef();
195-
bool isAttr = false, isValue = false;
196198

197-
for (auto [sc, _] : def->getSuperClasses()) {
198-
std::string scName = sc->getNameInitAsString();
199-
if (scName == "OptionalAttr")
200-
return translateArgumentType(def->getValue("baseAttr")->getValue(), rank);
201-
202-
if (scName == "TypedArrayAttrBase") {
203-
++rank;
204-
return translateArgumentType(def->getValue("elementAttr")->getValue(),
205-
rank);
206-
}
207-
208-
if (scName == "ElementsAttrBase") {
209-
rank += def->getValueAsInt("rank");
210-
return def->getValueAsString("elementReturnType").trim();
211-
}
212-
213-
if (scName == "Attr")
214-
isAttr = true;
215-
else if (scName == "TypeConstraint")
216-
isValue = true;
217-
else if (scName == "Variadic")
218-
++rank;
199+
llvm::StringSet superClasses;
200+
for (auto [sc, _] : def->getSuperClasses())
201+
superClasses.insert(sc->getNameInitAsString());
202+
203+
// Handle wrapper-style superclasses.
204+
if (superClasses.contains("OptionalAttr"))
205+
return translateArgumentType(loc, name,
206+
def->getValue("baseAttr")->getValue(), rank);
207+
208+
if (superClasses.contains("TypedArrayAttrBase"))
209+
return translateArgumentType(
210+
loc, name, def->getValue("elementAttr")->getValue(), ++rank);
211+
212+
// Handle ElementsAttrBase superclasses.
213+
if (superClasses.contains("ElementsAttrBase")) {
214+
// TODO: Support properly obtaining rank from ranked types.
215+
++rank;
216+
217+
if (superClasses.contains("IntElementsAttrBase"))
218+
return "::llvm::APInt";
219+
if (superClasses.contains("FloatElementsAttr") ||
220+
superClasses.contains("RankedFloatElementsAttr"))
221+
return "::llvm::APFloat";
222+
if (superClasses.contains("DenseArrayAttrBase"))
223+
return stripPrefixAndSuffix(def->getValueAsString("returnType"),
224+
{"::llvm::ArrayRef<"}, {">"});
225+
226+
// Reset the rank in the case where the base type cannot be inferred, so
227+
// that the bare storageType is used instead of a vector.
228+
rank = 0;
229+
PrintWarning(
230+
loc,
231+
"could not infer array-like attribute element type for argument '" +
232+
name->getAsUnquotedString() + "', will use bare `storageType`");
219233
}
220234

235+
// Handle simple attribute and value types.
236+
bool isAttr = superClasses.contains("Attr");
237+
bool isValue = superClasses.contains("TypeConstraint");
238+
if (superClasses.contains("Variadic"))
239+
++rank;
240+
221241
if (isValue) {
222242
assert(!isAttr &&
223243
"argument can't be simultaneously a value and an attribute");
@@ -246,7 +266,8 @@ static void genClauseOpsStruct(Record *clause, raw_ostream &os) {
246266
for (auto [name, arg] :
247267
zip_equal(arguments->getArgNames(), arguments->getArgs())) {
248268
int rank = 0;
249-
StringRef baseType = translateArgumentType(arg, rank);
269+
StringRef baseType =
270+
translateArgumentType(clause->getLoc(), name, arg, rank);
250271

251272
if (rank > 0)
252273
os << " ::llvm::SmallVector<" << baseType << ">";

0 commit comments

Comments
 (0)