Skip to content

Commit c714b44

Browse files
committed
[mlir][Shape] Make cstr_eq more like cstr_broadcastable
This includes allowing extents and not just shapes. Differential Revision: https://reviews.llvm.org/D97716
1 parent cc3d25b commit c714b44

File tree

4 files changed

+27
-10
lines changed

4 files changed

+27
-10
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable",
783783
let verifier = [{ return ::verify(*this); }];
784784
}
785785

786-
def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
786+
def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative, InferTypeOpInterface]> {
787787
let summary = "Determines if all input shapes are equal";
788788
let description = [{
789789
Given 1 or more input shapes, determine if all shapes are the exact same.
@@ -796,10 +796,21 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", [Commutative]> {
796796
%w1 = shape.cstr_eq [2,2], [1,2] // Failure
797797
```
798798
}];
799-
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
799+
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$shapes);
800800
let results = (outs Shape_WitnessType:$result);
801801

802-
let assemblyFormat = "$inputs attr-dict";
802+
let assemblyFormat = "$shapes attr-dict `:` type($shapes)";
803+
804+
let extraClassDeclaration = [{
805+
// TODO: This should really be automatic. Figure out how to not need this defined.
806+
static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context,
807+
::llvm::Optional<::mlir::Location> location, ::mlir::ValueRange operands,
808+
::mlir::DictionaryAttr attributes, ::mlir::RegionRange regions,
809+
::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes) {
810+
inferredReturnTypes.push_back(::mlir::shape::WitnessType::get(context));
811+
return success();
812+
};
813+
}];
803814

804815
let hasCanonicalizer = 1;
805816
let hasFolder = 1;

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ func @f(%arg0 : !shape.shape) {
360360
// CHECK-NEXT: shape.const_witness true
361361
// CHECK-NEXT: consume.witness
362362
// CHECK-NEXT: return
363-
%0 = shape.cstr_eq %arg0, %arg0, %arg0
363+
%0 = shape.cstr_eq %arg0, %arg0, %arg0 : !shape.shape, !shape.shape, !shape.shape
364364
"consume.witness"(%0) : (!shape.witness) -> ()
365365
return
366366
}
@@ -375,7 +375,7 @@ func @f() {
375375
%cs0 = shape.const_shape [0, 1] : !shape.shape
376376
%cs1 = shape.const_shape [0, 1] : !shape.shape
377377
%cs2 = shape.const_shape [0, 1] : !shape.shape
378-
%0 = shape.cstr_eq %cs0, %cs1, %cs2
378+
%0 = shape.cstr_eq %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
379379
"consume.witness"(%0) : (!shape.witness) -> ()
380380
return
381381
}
@@ -391,7 +391,7 @@ func @f() {
391391
// CHECK-NEXT: return
392392
%cs0 = shape.const_shape [0, 1] : !shape.shape
393393
%cs1 = shape.const_shape [3, 1] : !shape.shape
394-
%0 = shape.cstr_eq %cs0, %cs1
394+
%0 = shape.cstr_eq %cs0, %cs1 : !shape.shape, !shape.shape
395395
"consume.witness"(%0) : (!shape.witness) -> ()
396396
return
397397
}
@@ -403,7 +403,7 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
403403
// CHECK-NEXT: shape.cstr_eq
404404
// CHECK-NEXT: consume.witness
405405
// CHECK-NEXT: return
406-
%0 = shape.cstr_eq %arg0, %arg1
406+
%0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
407407
"consume.witness"(%0) : (!shape.witness) -> ()
408408
return
409409
}

mlir/test/Dialect/Shape/ops.mlir

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ func @test_constraints() {
102102
%1 = shape.const_shape [1, 2, 3] : !shape.shape
103103
%true = constant true
104104
%w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
105-
%w1 = shape.cstr_eq %0, %1
105+
%w1 = shape.cstr_eq %0, %1 : !shape.shape, !shape.shape
106106
%w2 = shape.const_witness true
107107
%w3 = shape.const_witness false
108108
%w4 = shape.cstr_require %true, "msg"
@@ -114,6 +114,12 @@ func @test_constraints() {
114114
return
115115
}
116116

117+
func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
118+
%rhs : tensor<?xindex>) {
119+
%w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
120+
return
121+
}
122+
117123
func @broadcastable_on_extent_tensors(%lhs : tensor<?xindex>,
118124
%rhs : tensor<?xindex>) {
119125
%w0 = shape.cstr_broadcastable %lhs, %rhs : tensor<?xindex>, tensor<?xindex>

mlir/test/Dialect/Shape/remove-shape-constraints.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
2929
// REPLACE: shape.assuming %[[WITNESS]]
3030
// CANON-NEXT: test.source
3131
// CANON-NEXT: return
32-
%0 = shape.cstr_eq %arg0, %arg1
32+
%0 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
3333
%1 = shape.assuming %0 -> index {
3434
%2 = "test.source"() : () -> (index)
3535
shape.assuming_yield %2 : index
@@ -46,7 +46,7 @@ func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> index {
4646
// CANON-NEXT: test.source
4747
// CANON-NEXT: return
4848
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
49-
%1 = shape.cstr_eq %arg0, %arg1
49+
%1 = shape.cstr_eq %arg0, %arg1 : !shape.shape, !shape.shape
5050
%2 = shape.assuming_all %0, %1
5151
%3 = shape.assuming %0 -> index {
5252
%4 = "test.source"() : () -> (index)

0 commit comments

Comments
 (0)