Skip to content

Commit 88f07a7

Browse files
committed
[mlir] Make UnitAttr's default val in unwrapped builder
UnitAttr is optional but unwrapped builders require it. Make Change onstructing from bool as required for when not set at moment (for UnitAttr nothing needs to be constructed, this is true for others here too and can be addressed together). Differential Revision: https://reviews.llvm.org/D135058
1 parent 65a961f commit 88f07a7

File tree

5 files changed

+23
-21
lines changed

5 files changed

+23
-21
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,6 @@ def SparseTensor_SortOp : SparseTensor_Op<"sort", [AttrSizedOperandSegments]>,
460460
let assemblyFormat = "(`stable` $stable^)? $n"
461461
"`,`$xs (`jointly` $ys^)? attr-dict"
462462
"`:` type($xs) (`jointly` type($ys)^)?";
463-
let builders = [
464-
OpBuilder<(ins "Value":$n, "ValueRange":$xs, "ValueRange":$ys)>
465-
];
466463
let hasVerifier = 1;
467464
}
468465

mlir/include/mlir/IR/OpBase.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1246,9 +1246,10 @@ class TypeAttrOf<Type ty>
12461246
// "true" if the attribute is present and "false" otherwise.
12471247
def UnitAttr : Attr<CPred<"$_self.isa<::mlir::UnitAttr>()">, "unit attribute"> {
12481248
let storageType = [{ ::mlir::UnitAttr }];
1249-
let constBuilderCall = "$_builder.getUnitAttr()";
1249+
let constBuilderCall = "(($0) ? $_builder.getUnitAttr() : nullptr)";
12501250
let convertFromStorage = "$_self != nullptr";
12511251
let returnType = "bool";
1252+
let defaultValue = "false";
12521253
let valueType = NoneType;
12531254
let isOptional = 1;
12541255
}
@@ -1575,7 +1576,7 @@ class ConstantAttr<Attr attribute, string val> : AttrConstraint<
15751576
class ConstF32Attr<string val> : ConstantAttr<F32Attr, val>;
15761577
def ConstBoolAttrFalse : ConstantAttr<BoolAttr, "false">;
15771578
def ConstBoolAttrTrue : ConstantAttr<BoolAttr, "true">;
1578-
def ConstUnitAttr : ConstantAttr<UnitAttr, "unit">;
1579+
def ConstUnitAttr : ConstantAttr<UnitAttr, "true">;
15791580

15801581
// Constant string-based attribute. Wraps the desired string in escaped quotes.
15811582
class ConstantStrAttr<Attr attribute, string val>

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -706,11 +706,6 @@ LogicalResult SelectOp::verify() {
706706
return success();
707707
}
708708

709-
void SortOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value n,
710-
ValueRange xs, ValueRange ys) {
711-
build(odsBuilder, odsState, n, xs, ys, /*stable=*/false);
712-
}
713-
714709
LogicalResult SortOp::verify() {
715710
if (getXs().empty())
716711
return emitError("need at least one xs buffer.");

mlir/test/mlir-tblgen/op-attribute.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,11 @@ def UnitAttrOp : NS_Op<"unit_attr_op", []> {
488488
// DEF-NEXT: (*this)->removeAttr(getAttrAttrName());
489489

490490
// DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::UnitAttr attr)
491+
// DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/bool attr)
492+
493+
// DECL-LABEL: UnitAttrOp declarations
494+
// DECL-NOT: declarations
495+
// DECL: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/bool attr = false)
491496

492497

493498
// Test elementAttr field of TypedArrayAttr.

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1635,9 +1635,9 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() {
16351635
}
16361636

16371637
void OpEmitter::genPopulateDefaultAttributes() {
1638-
// All done if no attributes have default values.
1638+
// All done if no attributes, except optional ones, have default values.
16391639
if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) {
1640-
return !named.attr.hasDefaultValue();
1640+
return !named.attr.hasDefaultValue() || named.attr.isOptional();
16411641
}))
16421642
return;
16431643

@@ -1667,8 +1667,8 @@ void OpEmitter::genPopulateDefaultAttributes() {
16671667
fctx.withBuilder(odsBuilder);
16681668
std::string defaultValue = std::string(
16691669
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
1670-
body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n",
1671-
index, defaultValue);
1670+
body.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index,
1671+
defaultValue);
16721672
body.unindent() << "}\n";
16731673
}
16741674
}
@@ -2143,12 +2143,16 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
21432143
if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
21442144
continue;
21452145

2146-
// TODO(jpienaar): The wrapping of optional is different for default or not,
2147-
// so don't unwrap for default ones that would fail below.
2148-
bool emitNotNullCheck = (attr.isOptional() && !attr.hasDefaultValue()) ||
2149-
(attr.hasDefaultValue() && !isRawValueAttr);
2146+
// TODO: The wrapping of optional is different for default or not, so don't
2147+
// unwrap for default ones that would fail below.
2148+
bool emitNotNullCheck =
2149+
(attr.isOptional() && !attr.hasDefaultValue()) ||
2150+
(attr.hasDefaultValue() && !isRawValueAttr) ||
2151+
// TODO: UnitAttr is optional, not wrapped, but needs to be guarded as
2152+
// the constant materialization is only for true case.
2153+
(isRawValueAttr && attr.getAttrDefName() == "UnitAttr");
21502154
if (emitNotNullCheck)
2151-
body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
2155+
body.indent() << formatv("if ({0}) ", namedAttr.name) << "{\n";
21522156

21532157
if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
21542158
// If this is a raw value, then we need to wrap it in an Attribute
@@ -2175,7 +2179,7 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
21752179
namedAttr.name);
21762180
}
21772181
if (emitNotNullCheck)
2178-
body << " }\n";
2182+
body.unindent() << " }\n";
21792183
}
21802184

21812185
// Create the correct number of regions.
@@ -2966,7 +2970,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
29662970
// call. This should be set instead.
29672971
std::string defaultValue = std::string(
29682972
tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue()));
2969-
body << " if (!attr)\n attr = " << defaultValue << ";\n";
2973+
body << "if (!attr)\n attr = " << defaultValue << ";\n";
29702974
}
29712975
body << "return attr;\n";
29722976
};

0 commit comments

Comments
 (0)