Skip to content

Commit b4660dc

Browse files
committed
[mlir][ODS] Add OptionalTypesMatchWith and remove a custom assemblyFormat
This is just a slight specialization of `TypesMatchWith` that returns success if an optional parameter is missing. There may be other places this could help e.g.: https://github.com/llvm/llvm-project/blob/eb21049b4b904b072679ece60e73c6b0dc0d1ebf/mlir/include/mlir/Dialect/X86Vector/X86Vector.td#L58C5-L58C5 But I'm leaving those to avoid some churn. (This constraint will be handy for us in some later patches)
1 parent b98b567 commit b4660dc

File tree

5 files changed

+20
-52
lines changed

5 files changed

+20
-52
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ def Vector_ReductionOp :
215215
Vector_Op<"reduction", [Pure,
216216
PredOpTrait<"source operand and result have same element type",
217217
TCresVTEtIsSameAsOpBase<0, 0>>,
218+
OptionalTypesMatchWith<"dest and acc have the same type",
219+
"dest", "acc", "::llvm::cast<Type>($_self)">,
218220
DeclareOpInterfaceMethods<ArithFastMathInterface>,
219221
DeclareOpInterfaceMethods<MaskableOpInterface>,
220222
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
@@ -263,9 +265,8 @@ def Vector_ReductionOp :
263265
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
264266
];
265267

266-
// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
267-
// operands.
268-
let hasCustomAssemblyFormat = 1;
268+
let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` $fastmath^)?"
269+
" attr-dict `:` type($vector) `into` type($dest)";
269270
let hasCanonicalizer = 1;
270271
let hasVerifier = 1;
271272
}

mlir/include/mlir/IR/OpBase.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,21 @@ class TypesMatchWith<string summary, string lhsArg, string rhsArg,
568568
string transformer = transform;
569569
}
570570

571+
// Helper which makes the first letter of a string uppercase.
572+
// e.g. cat -> Cat
573+
class first_char_to_upper<string str>
574+
{
575+
string ret = !toupper(!substr(str, 0, 1)) # !substr(str, 1);
576+
}
577+
578+
// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
579+
// and not present returns success.
580+
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
581+
string transform, string comparator = "std::equal_to<>()">
582+
: TypesMatchWith<summary, lhsArg, rhsArg, transform,
583+
"!get" # first_char_to_upper<lhsArg>.ret # "()"
584+
# " || !get" # first_char_to_upper<rhsArg>.ret # "() || " # comparator>;
585+
571586
// Special variant of `TypesMatchWith` that provides a comparator suitable for
572587
// ranged arguments.
573588
class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -485,47 +485,6 @@ LogicalResult ReductionOp::verify() {
485485
return success();
486486
}
487487

488-
ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
489-
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
490-
Type redType;
491-
Type resType;
492-
CombiningKindAttr kindAttr;
493-
arith::FastMathFlagsAttr fastMathAttr;
494-
if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
495-
result.attributes) ||
496-
parser.parseComma() || parser.parseOperandList(operandsInfo) ||
497-
(succeeded(parser.parseOptionalKeyword("fastmath")) &&
498-
parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath",
499-
result.attributes)) ||
500-
parser.parseColonType(redType) ||
501-
parser.parseKeywordType("into", resType) ||
502-
(!operandsInfo.empty() &&
503-
parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
504-
(operandsInfo.size() > 1 &&
505-
parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
506-
parser.addTypeToList(resType, result.types))
507-
return failure();
508-
if (operandsInfo.empty() || operandsInfo.size() > 2)
509-
return parser.emitError(parser.getNameLoc(),
510-
"unsupported number of operands");
511-
return success();
512-
}
513-
514-
void ReductionOp::print(OpAsmPrinter &p) {
515-
p << " ";
516-
getKindAttr().print(p);
517-
p << ", " << getVector();
518-
if (getAcc())
519-
p << ", " << getAcc();
520-
521-
if (getFastmathAttr() &&
522-
getFastmathAttr().getValue() != arith::FastMathFlags::none) {
523-
p << ' ' << getFastmathAttrName().getValue();
524-
p.printStrippedAttrOrType(getFastmathAttr());
525-
}
526-
p << " : " << getVector().getType() << " into " << getDest().getType();
527-
}
528-
529488
// MaskableOpInterface methods.
530489

531490
/// Returns the mask type expected by this operation.

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1168,13 +1168,6 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 {
11681168

11691169
// -----
11701170

1171-
func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 {
1172-
// expected-error@+1 {{'vector.reduction' unsupported number of operands}}
1173-
%0 = vector.reduction <add>, %arg0, %arg1, %arg1 : vector<16xf32> into f32
1174-
}
1175-
1176-
// -----
1177-
11781171
func.func @reduce_unsupported_rank(%arg0: vector<4x16xf32>) -> f32 {
11791172
// expected-error@+1 {{'vector.reduction' op unsupported reduction rank: 2}}
11801173
%0 = vector.reduction <add>, %arg0 : vector<4x16xf32> into f32

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1042,7 +1042,7 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>,
10421042

10431043
// CHECK-LABEL: func.func @fastmath(
10441044
func.func @fastmath(%x: vector<42xf32>) -> f32 {
1045-
// CHECK: vector.reduction <minf>, %{{.*}} fastmath<reassoc,nnan,ninf>
1045+
// CHECK: vector.reduction <minf>, %{{.*}} fastmath <reassoc,nnan,ninf>
10461046
%min = vector.reduction <minf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
10471047
return %min: f32
10481048
}

0 commit comments

Comments
 (0)