@@ -28,7 +28,9 @@ include "mlir/IR/SymbolInterfaces.td"
28
28
class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
29
29
Op<ShapeDialect, mnemonic, traits>;
30
30
31
- def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
31
+ def Shape_AddOp : Shape_Op<"add",
32
+ [Commutative, NoSideEffect,
33
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
32
34
let summary = "Addition of sizes and indices";
33
35
let description = [{
34
36
Adds two sizes or indices. If either operand is an error it will be
@@ -47,6 +49,12 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
47
49
}];
48
50
49
51
let verifier = [{ return verifySizeOrIndexOp(*this); }];
52
+
53
+ let extraClassDeclaration = [{
54
+ // Returns when two result types are compatible for this op; method used by
55
+ // InferTypeOpInterface
56
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
57
+ }];
50
58
}
51
59
52
60
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
@@ -77,6 +85,8 @@ def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative, NoSideEffect]> {
77
85
OptionalAttr<StrAttr>:$error);
78
86
let results = (outs Shape_ShapeOrExtentTensorType:$result);
79
87
88
+ let builders = [OpBuilder<(ins "Value":$shape)>];
89
+
80
90
let assemblyFormat = [{
81
91
$shapes attr-dict `:` type($shapes) `->` type($result)
82
92
}];
@@ -145,7 +155,8 @@ def Shape_ConstSizeOp : Shape_Op<"const_size", [
145
155
let hasFolder = 1;
146
156
}
147
157
148
- def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
158
+ def Shape_DivOp : Shape_Op<"div", [NoSideEffect,
159
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
149
160
let summary = "Division of sizes and indices";
150
161
let description = [{
151
162
Divides two sizes or indices. If either operand is an error it will be
@@ -173,10 +184,16 @@ def Shape_DivOp : Shape_Op<"div", [NoSideEffect]> {
173
184
174
185
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
175
186
let hasFolder = 1;
187
+
188
+ let extraClassDeclaration = [{
189
+ // Returns when two result types are compatible for this op; method used by
190
+ // InferTypeOpInterface
191
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
192
+ }];
176
193
}
177
194
178
- def Shape_ShapeEqOp : Shape_Op<"shape_eq", [NoSideEffect, Commutative,
179
- InferTypeOpInterface]> {
195
+ def Shape_ShapeEqOp : Shape_Op<"shape_eq",
196
+ [NoSideEffect, Commutative, InferTypeOpInterface]> {
180
197
let summary = "Returns whether the input shapes or extent tensors are equal";
181
198
let description = [{
182
199
Takes one or more shape or extent tensor operands and determines whether
@@ -290,7 +307,8 @@ def Shape_IsBroadcastableOp : Shape_Op<"is_broadcastable",
290
307
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
291
308
}
292
309
293
- def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
310
+ def Shape_RankOp : Shape_Op<"rank",
311
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
294
312
let summary = "Gets the rank of a shape";
295
313
let description = [{
296
314
Returns the rank of the shape or extent tensor, i.e. the number of extents.
@@ -304,6 +322,12 @@ def Shape_RankOp : Shape_Op<"rank", [NoSideEffect]> {
304
322
let hasFolder = 1;
305
323
let hasCanonicalizer = 1;
306
324
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
325
+
326
+ let extraClassDeclaration = [{
327
+ // Returns when two result types are compatible for this op; method used by
328
+ // InferTypeOpInterface
329
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
330
+ }];
307
331
}
308
332
309
333
def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
@@ -324,7 +348,8 @@ def Shape_ToExtentTensorOp : Shape_Op<"to_extent_tensor", [NoSideEffect]> {
324
348
let hasFolder = 1;
325
349
}
326
350
327
- def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
351
+ def Shape_GetExtentOp : Shape_Op<"get_extent",
352
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
328
353
let summary = "Gets the specified extent from a shape or extent tensor";
329
354
let description = [{
330
355
Gets the extent indexed by `dim` from the `shape` operand. If the shape is
@@ -344,6 +369,9 @@ def Shape_GetExtentOp : Shape_Op<"get_extent", [NoSideEffect]> {
344
369
let extraClassDeclaration = [{
345
370
/// Get the `dim` value as integer if it is constant.
346
371
Optional<int64_t> getConstantDim();
372
+ /// Returns when two result types are compatible for this op; method used by
373
+ /// InferTypeOpInterface
374
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
347
375
}];
348
376
349
377
let hasFolder = 1;
@@ -369,7 +397,8 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
369
397
let hasCanonicalizer = 1;
370
398
}
371
399
372
- def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
400
+ def Shape_JoinOp : Shape_Op<"join",
401
+ [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
373
402
let summary = "Returns the least general shape.shape of its operands";
374
403
let description = [{
375
404
An operation that computes the least general shape of input operands.
@@ -405,9 +434,17 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
405
434
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
406
435
type($arg0) `,` type($arg1) `->` type($result)
407
436
}];
437
+
438
+ let extraClassDeclaration = [{
439
+ // Returns when two result types are compatible for this op; method used by
440
+ // InferTypeOpInterface
441
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
442
+ }];
408
443
}
409
444
410
- def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
445
+ def Shape_MaxOp : Shape_Op<"max",
446
+ [Commutative, NoSideEffect,
447
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
411
448
let summary = "Elementwise maximum";
412
449
let description = [{
413
450
Computes the elementwise maximum of two sizes or shapes with equal ranks.
@@ -424,9 +461,17 @@ def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
424
461
}];
425
462
426
463
let hasFolder = 1;
464
+
465
+ let extraClassDeclaration = [{
466
+ // Returns when two result types are compatible for this op; method used by
467
+ // InferTypeOpInterface
468
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
469
+ }];
427
470
}
428
471
429
- def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
472
+ def Shape_MinOp : Shape_Op<"min",
473
+ [Commutative, NoSideEffect,
474
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
430
475
let summary = "Elementwise minimum";
431
476
let description = [{
432
477
Computes the elementwise minimum of two sizes or shapes with equal ranks.
@@ -443,9 +488,17 @@ def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
443
488
}];
444
489
445
490
let hasFolder = 1;
491
+
492
+ let extraClassDeclaration = [{
493
+ // Returns when two result types are compatible for this op; method used by
494
+ // InferTypeOpInterface
495
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
496
+ }];
446
497
}
447
498
448
- def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
499
+ def Shape_MulOp : Shape_Op<"mul",
500
+ [Commutative, NoSideEffect,
501
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
449
502
let summary = "Multiplication of sizes and indices";
450
503
let description = [{
451
504
Multiplies two sizes or indices. If either operand is an error it will be
@@ -465,9 +518,16 @@ def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
465
518
466
519
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
467
520
let hasFolder = 1;
521
+
522
+ let extraClassDeclaration = [{
523
+ // Returns when two result types are compatible for this op; method used by
524
+ // InferTypeOpInterface
525
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
526
+ }];
468
527
}
469
528
470
- def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
529
+ def Shape_NumElementsOp : Shape_Op<"num_elements",
530
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
471
531
let summary = "Returns the number of elements for a given shape";
472
532
let description = [{
473
533
Returns the number of elements for a given shape which is the product of its
@@ -480,12 +540,15 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> {
480
540
let arguments = (ins Shape_ShapeOrExtentTensorType:$shape);
481
541
let results = (outs Shape_SizeOrIndexType:$result);
482
542
483
- let builders = [OpBuilder<(ins "Value":$shape)>];
484
-
485
543
let assemblyFormat = "$shape attr-dict `:` type($shape) `->` type($result)";
486
544
487
545
let hasFolder = 1;
488
546
let verifier = [{ return ::verifySizeOrIndexOp(*this); }];
547
+ let extraClassDeclaration = [{
548
+ // Returns when two result types are compatible for this op; method used by
549
+ // InferTypeOpInterface
550
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
551
+ }];
489
552
}
490
553
491
554
def Shape_ReduceOp : Shape_Op<"reduce",
@@ -535,7 +598,8 @@ def Shape_ReduceOp : Shape_Op<"reduce",
535
598
let parser = [{ return ::parse$cppClass(parser, result); }];
536
599
}
537
600
538
- def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
601
+ def Shape_ShapeOfOp : Shape_Op<"shape_of",
602
+ [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
539
603
let summary = "Returns shape of a value or shaped type operand";
540
604
541
605
let description = [{
@@ -548,11 +612,15 @@ def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> {
548
612
549
613
let assemblyFormat = "$arg attr-dict `:` type($arg) `->` type($result)";
550
614
551
- let builders = [OpBuilder<(ins "Value":$arg)>];
552
-
553
615
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
554
616
let hasCanonicalizer = 1;
555
617
let hasFolder = 1;
618
+
619
+ let extraClassDeclaration = [{
620
+ // Returns when two result types are compatible for this op; method used by
621
+ // InferTypeOpInterface
622
+ static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
623
+ }];
556
624
}
557
625
558
626
def Shape_SizeToIndexOp : Shape_Op<"size_to_index", [NoSideEffect]> {
0 commit comments