Skip to content

Commit 41e5dbe

Browse files
committed
Enables inferring return types for Shape op if possible
Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D102565
1 parent c22b64e commit 41e5dbe

File tree

4 files changed

+316
-37
lines changed

4 files changed

+316
-37
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 84 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ include "mlir/IR/SymbolInterfaces.td"
2828
class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
2929
Op<ShapeDialect, mnemonic, traits>;
3030

31-
def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
31+
def Shape_AddOp : Shape_Op<"add",
32+
[Commutative, NoSideEffect,
33+
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
3234
let summary = "Addition of sizes and indices";
3335
let description = [{
3436
Adds two sizes or indices. If either operand is an error it will be
@@ -47,6 +49,12 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
4749
}];
4850

4951
let verifier = [{ return verifySizeOrIndexOp(*this); }];
52+
53+
let extraClassDeclaration = [{
54+
// Returns when two result types are compatible for this op; method used by
55+
// InferTypeOpInterface
56+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
57+
}];
5058
}
5159

5260
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
@@ -77,6 +85,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
7785
OptionalAttr<StrAttr>:$error);
7886
let results = (outs Shape_ShapeOrExtentTensorType:$result);
7987

88+
let builders = [OpBuilder<(ins "Value":$shape)>];
89+
8090
let assemblyFormat = [{
8191
$shapes attr-dict `:` type($shapes) `->` type($result)
8292
}];
@@ -145,7 +155,8 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
145155
let hasFolder = 1;
146156
}
147157

148-
def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
158+
def Shape_DivOp : Shape_Op<"div", [NoSideEffect,
159+
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
149160
let summary = "Division of sizes and indices";
150161
let description = [{
151162
Divides two sizes or indices. If either operand is an error it will be
@@ -173,10 +184,16 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
173184

174185
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
175186
let hasFolder = 1;
187+
188+
let extraClassDeclaration = [{
189+
// Returns when two result types are compatible for this op; method used by
190+
// InferTypeOpInterface
191+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
192+
}];
176193
}
177194

178-
def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
179-
InferTypeOpInterface]> {
195+
def Shape_ShapeEqOp : Shape_Op<"shape_eq",
196+
[NoSideEffect, Commutative, InferTypeOpInterface]> {
180197
let summary = "Returns whether the input shapes or extent tensors are equal";
181198
let description = [{
182199
Takes one or more shape or extent tensor operands and determines whether
@@ -290,7 +307,8 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
290307
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
291308
}
292309

293-
def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
310+
def Shape_RankOp : Shape_Op<"rank",
311+
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
294312
let summary = "Gets the rank of a shape";
295313
let description = [{
296314
Returns the rank of the shape or extent tensor, i.e. the number of extents.
@@ -304,6 +322,12 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
304322
let hasFolder = 1;
305323
let hasCanonicalizer = 1;
306324
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
325+
326+
let extraClassDeclaration = [{
327+
// Returns when two result types are compatible for this op; method used by
328+
// InferTypeOpInterface
329+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
330+
}];
307331
}
308332

309333
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -324,7 +348,8 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
324348
let hasFolder = 1;
325349
}
326350

327-
def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
351+
def Shape_GetExtentOp : Shape_Op<"get_extent",
352+
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
328353
let summary = "Gets the specified extent from a shape or extent tensor";
329354
let description = [{
330355
Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@@ -344,6 +369,9 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
344369
let extraClassDeclaration = [{
345370
/// Get the `dim` value as integer if it is constant.
346371
Optional<int64_t> getConstantDim();
372+
/// Returns when two result types are compatible for this op; method used by
373+
/// InferTypeOpInterface
374+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
347375
}];
348376

349377
let hasFolder = 1;
@@ -369,7 +397,8 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
369397
let hasCanonicalizer = 1;
370398
}
371399

372-
def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
400+
def Shape_JoinOp : Shape_Op<"join",
401+
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
373402
let summary = "Returns the least general shape.shape of its operands";
374403
let description = [{
375404
An operation that computes the least general shape of input operands.
@@ -405,9 +434,17 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
405434
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
406435
type($arg0) `,` type($arg1) `->` type($result)
407436
}];
437+
438+
let extraClassDeclaration = [{
439+
// Returns when two result types are compatible for this op; method used by
440+
// InferTypeOpInterface
441+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
442+
}];
408443
}
409444

410-
def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
445+
def Shape_MaxOp : Shape_Op<"max",
446+
[Commutative, NoSideEffect,
447+
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
411448
let summary = "Elementwise maximum";
412449
let description = [{
413450
Computes the elementwise maximum of two sizes or shapes with equal ranks.
@@ -424,9 +461,17 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
424461
}];
425462

426463
let hasFolder = 1;
464+
465+
let extraClassDeclaration = [{
466+
// Returns when two result types are compatible for this op; method used by
467+
// InferTypeOpInterface
468+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
469+
}];
427470
}
428471

429-
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
472+
def Shape_MinOp : Shape_Op<"min",
473+
[Commutative, NoSideEffect,
474+
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
430475
let summary = "Elementwise minimum";
431476
let description = [{
432477
Computes the elementwise minimum of two sizes or shapes with equal ranks.
@@ -443,9 +488,17 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
443488
}];
444489

445490
let hasFolder = 1;
491+
492+
let extraClassDeclaration = [{
493+
// Returns when two result types are compatible for this op; method used by
494+
// InferTypeOpInterface
495+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
496+
}];
446497
}
447498

448-
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
499+
def Shape_MulOp : Shape_Op<"mul",
500+
[Commutative, NoSideEffect,
501+
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
449502
let summary = "Multiplication of sizes and indices";
450503
let description = [{
451504
Multiplies two sizes or indices. If either operand is an error it will be
@@ -465,9 +518,16 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
465518

466519
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
467520
let hasFolder = 1;
521+
522+
let extraClassDeclaration = [{
523+
// Returns when two result types are compatible for this op; method used by
524+
// InferTypeOpInterface
525+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
526+
}];
468527
}
469528

470-
def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
529+
def Shape_NumElementsOp : Shape_Op<"num_elements",
530+
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
471531
let summary = "Returns the number of elements for a given shape";
472532
let description = [{
473533
Returns the number of elements for a given shape which is the product of its
@@ -480,12 +540,15 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
480540
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
481541
let results = (outs Shape_SizeOrIndexType:$result);
482542

483-
let builders = [OpBuilder<(ins "Value":$shape)>];
484-
485543
let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
486544

487545
let hasFolder = 1;
488546
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
547+
let extraClassDeclaration = [{
548+
// Returns when two result types are compatible for this op; method used by
549+
// InferTypeOpInterface
550+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
551+
}];
489552
}
490553

491554
def Shape_ReduceOp : Shape_Op<"reduce",
@@ -535,7 +598,8 @@ def Shape_ReduceOp : Shape_Op<"reduce",
535598
let parser = [{ return ::parse$cppClass(parser, result); }];
536599
}
537600

538-
def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
601+
def Shape_ShapeOfOp : Shape_Op<"shape_of",
602+
[NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
539603
let summary = "Returns shape of a value or shaped type operand";
540604

541605
let description = [{
@@ -548,11 +612,15 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
548612

549613
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
550614

551-
let builders = [OpBuilder<(ins "Value":$arg)>];
552-
553615
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
554616
let hasCanonicalizer = 1;
555617
let hasFolder = 1;
618+
619+
let extraClassDeclaration = [{
620+
// Returns when two result types are compatible for this op; method used by
621+
// InferTypeOpInterface
622+
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
623+
}];
556624
}
557625

558626
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {

mlir/include/mlir/Interfaces/InferTypeOpInterface.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def InferTypeOpInterface : OpInterface<"InferTypeOpInterface"> {
3434
The method takes an optional location which, if set, will be used to
3535
report errors on. The operands and attributes correspond to those with
3636
which an Operation would be created (e.g., as used in Operation::create)
37-
and the regions of the op.
37+
and the regions of the op. Be aware that this method is supposed to be
38+
called with valid arguments, e.g., operands are verified, or it may result
39+
in an undefined behavior.
3840
}],
3941
/*retTy=*/"::mlir::LogicalResult",
4042
/*methodName=*/"inferReturnTypes",

0 commit comments

Comments
 (0)