Skip to content

Commit 7f600da

Browse files
committed
[MLIR][Shape] Allow shape.any to operate on extent tensors
Differential Revision: https://reviews.llvm.org/D84433
1 parent 274db1d commit 7f600da

File tree

4 files changed

+43
-19
lines changed

4 files changed

+43
-19
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -509,11 +509,14 @@ def Shape_ConcatOp : Shape_Op<"concat", []> {
509509
//===----------------------------------------------------------------------===//
510510

511511
// TODO: Move the code below and witnesses to a different file.
512-
def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
512+
def Shape_AnyOp : Shape_Op<"any", [Commutative,
513+
NoSideEffect,
514+
SameOperandsAndResultType]> {
513515
let summary = "Return any combination of the input shapes";
514516
let description = [{
515-
This operation takes multiple input shapes and returns some combination of
516-
their dimensions. This can be best seen with examples below.
517+
This operation takes multiple input shapes or extent tensors and returns
518+
some combination of their dimensions. This can be best seen with examples
519+
below.
517520

518521
The result is undefined, but still side-effect free, in cases where the
519522
inputs have differing ranks or differ in extents of shared dimensions.
@@ -525,11 +528,10 @@ def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
525528
```
526529
}];
527530

528-
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
529-
let results = (outs Shape_ShapeType:$result);
530-
531-
let assemblyFormat = "$inputs attr-dict";
531+
let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
532+
let results = (outs Shape_ShapeOrExtentTensorType:$result);
532533

534+
let assemblyFormat = "$inputs `:` type($result) attr-dict";
533535
let hasFolder = 1;
534536
}
535537

mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -165,21 +165,22 @@ func @get_extent_from_extent_tensor(%extents : tensor<?xindex>,
165165
// Lower `any` to its first operand.
166166
// CHECK-LABEL: @any_of_three
167167
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
168-
func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
169-
-> !shape.shape {
168+
func @any_of_three(%a : tensor<?xindex>,
169+
%b : tensor<?xindex>,
170+
%c : tensor<?xindex>) -> tensor<?xindex> {
170171
// CHECK: return %[[A]] : tensor<?xindex>
171-
%result = shape.any %a, %b, %c
172-
return %result : !shape.shape
172+
%result = shape.any %a, %b, %c : tensor<?xindex>
173+
return %result : tensor<?xindex>
173174
}
174175

175176
// -----
176177

177178
// Lower `any` to its first operand.
178179
// CHECK-LABEL: @any_of_one
179180
// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
180-
func @any_of_one(%a : !shape.shape) -> !shape.shape {
181+
func @any_of_one(%a : tensor<?xindex>) -> tensor<?xindex> {
181182
// CHECK: return %[[A]] : tensor<?xindex>
182-
%result = shape.any %a
183-
return %result : !shape.shape
183+
%result = shape.any %a : tensor<?xindex>
184+
return %result : tensor<?xindex>
184185
}
185186

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,14 +364,25 @@ func @f() {
364364

365365
// any can be replaced with a constant input if it has one.
366366
// CHECK-LABEL: func @f
367-
func @f(%arg0 : !shape.shape) -> !shape.shape {
367+
func @f(%arg : !shape.shape) -> !shape.shape {
368368
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape
369369
// CHECK-NEXT: return %[[CS]]
370370
%0 = shape.const_shape [2, 3, 4] : !shape.shape
371-
%1 = shape.any %0, %arg0
371+
%1 = shape.any %0, %arg : !shape.shape
372372
return %1 : !shape.shape
373373
}
374374

375+
// -----
376+
377+
// any can be replaced with a constant input if it has one.
378+
// CHECK-LABEL: func @f
379+
func @f(%arg : tensor<?xindex>) -> tensor<?xindex> {
380+
// CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
381+
// CHECK-NEXT: return %[[CS]] : tensor<?xindex>
382+
%0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
383+
%1 = shape.any %0, %arg : tensor<?xindex>
384+
return %1 : tensor<?xindex>
385+
}
375386

376387
// -----
377388

@@ -380,7 +391,7 @@ func @f(%arg0 : !shape.shape) -> !shape.shape {
380391
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
381392
// CHECK-NEXT: %[[CS:.*]] = shape.any
382393
// CHECK-NEXT: return %[[CS]]
383-
%1 = shape.any %arg0, %arg1
394+
%1 = shape.any %arg0, %arg1 : !shape.shape
384395
return %1 : !shape.shape
385396
}
386397

mlir/test/Dialect/Shape/ops.mlir

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s
21
// Verify the printed output can be parsed.
32
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
43
// Verify the generic form can be parsed.
@@ -99,7 +98,7 @@ func @test_constraints() {
9998
%w3 = shape.const_witness false
10099
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
101100
shape.assuming %w4 -> !shape.shape {
102-
%2 = shape.any %0, %1
101+
%2 = shape.any %0, %1 : !shape.shape
103102
shape.assuming_yield %2 : !shape.shape
104103
}
105104
return
@@ -173,3 +172,14 @@ func @get_extent_on_extent_tensor(%arg : tensor<?xindex>) -> !shape.size {
173172
%result = shape.get_extent %arg, %c0 : tensor<?xindex>
174173
return %result : !shape.size
175174
}
175+
176+
func @any() {
177+
%0 = shape.const_shape [1, 2, 3] : !shape.shape
178+
%1 = shape.const_shape [4, 5, 6] : !shape.shape
179+
%2 = shape.any %0, %1 : !shape.shape
180+
%3 = shape.const_shape [1, 2, 3] : tensor<?xindex>
181+
%4 = shape.const_shape [4, 5, 6] : tensor<?xindex>
182+
%5 = shape.any %3, %4 : tensor<?xindex>
183+
return
184+
}
185+

0 commit comments

Comments
 (0)