Skip to content

Commit 6983cf3

Browse files
committed
[MLIR][Shape] Allow unsafe shape.broadcast
In a context in which `shape.broadcast` is known not to produce an error value, we want it to operate solely on extent tensors. The operation's behavior is then undefined in the error case as the result type cannot hold this value. Differential Revision: https://reviews.llvm.org/D84933
1 parent 2da9b44 commit 6983cf3

File tree

3 files changed

+53
-24
lines changed

3 files changed

+53
-24
lines changed

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,34 +49,36 @@ def Shape_AddOp : Shape_Op<"add", [Commutative, NoSideEffect]> {
4949
def Shape_BroadcastOp : Shape_Op<"broadcast", [Commutative]> {
5050
let summary = "Returns the broadcasted output shape of two inputs";
5151
let description = [{
52-
Computes the broadcasted output shape following:
53-
1. If any inputs are unranked, output is unranked;
54-
2. Else the input array with number of dimensions smaller than the max
55-
input dimension, has 1’s prepended to its shapes and the output shape is
56-
calculated as follows:
57-
58-
output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined
59-
= rhs[i] if lhs[i] is unknown/undefined
60-
= lhs[i] if rhs[i] == 1
61-
= rhs[i] if lhs[i] == 1
62-
= error if lhs[i] != rhs[i]
63-
64-
Op has an optional string attribute for the error case where there is no
65-
broadcastable output shape possible for the given inputs.
66-
67-
Op may also return an ExtentTensor, but this should only be done when this
68-
is statically guaranteed to never fail, either because of a dependency on a
69-
cstr_broadcastable operation or other details of the construction of the
70-
program.
52+
Returns the broadcasted shape for two input shapes or extent tensors. Both
53+
operands can be of type `shape.shape` or `tensor<?xindex>`. The result is of
54+
type `shape.shape` and, if both operands are tensors, may be of type
55+
`tensor<?xindex>`.
56+
57+
If the two operand shapes are of different rank the smaller one is padded
58+
with 1's from the left. The resulting broadcasted shape is then defined as
59+
60+
result[i] = lhs[i] if lhs[i] == rhs[i]
61+
= lhs[i] if rhs[i] == 1
62+
= rhs[i] if lhs[i] == 1.
63+
64+
In case the resulting shape is undefined, i.e. if corresponding extents are
65+
different from each other but none is 1, the result is an error shape.
66+
Likewise error values are propagated if any of the operands holds an error
67+
value. If the result type is an extent tensor (and can therefore not hold
68+
the error value) the behavior may be undefined. The optional string
69+
attribute can be used to describe the error case.
7170
}];
7271

7372
let arguments = (ins Shape_ShapeOrExtentTensorType:$lhs,
7473
Shape_ShapeOrExtentTensorType:$rhs,
7574
OptionalAttr<StrAttr>:$error);
7675
let results = (outs Shape_ShapeOrExtentTensorType:$result);
7776

78-
let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)";
77+
let assemblyFormat = [{
78+
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
79+
}];
7980

81+
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];
8082
let hasFolder = 1;
8183

8284
let verifier = [{ return ::verifyShapeOrExtentTensorOp(*this); }];

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,31 @@ func @f() -> !shape.shape {
6060

6161
// -----
6262

63+
// Basic case including extent tensors.
64+
// CHECK-LABEL: @broadcast
65+
func @broadcast() -> tensor<?xindex> {
66+
// CHECK: shape.const_shape [7, 2] : tensor<?xindex>
67+
%0 = shape.const_shape [1, 2] : tensor<?xindex>
68+
%1 = shape.const_shape [7, 1] : tensor<?xindex>
69+
%2 = shape.broadcast %0, %1
70+
: tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
71+
return %2 : tensor<?xindex>
72+
}
73+
74+
// -----
75+
76+
// Basic case including extent tensors.
77+
// CHECK-LABEL: @broadcast
78+
func @broadcast() -> !shape.shape {
79+
// CHECK: shape.const_shape [7, 2] : !shape.shape
80+
%0 = shape.const_shape [1, 2] : tensor<?xindex>
81+
%1 = shape.const_shape [7, 1] : tensor<?xindex>
82+
%2 = shape.broadcast %0, %1 : tensor<?xindex>, tensor<?xindex> -> !shape.shape
83+
return %2 : !shape.shape
84+
}
85+
86+
// -----
87+
6388
// Rhs is a scalar.
6489
// CHECK-LABEL: func @f
6590
func @f(%arg0 : !shape.shape) -> !shape.shape {

mlir/test/Dialect/Shape/invalid.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,19 @@ func @add(%lhs : !shape.size, %rhs : index) -> index {
138138

139139
// -----
140140

141-
func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
141+
func @broadcast(%arg0 : !shape.shape, %arg1 : !shape.shape) -> tensor<?xindex> {
142142
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
143-
%result = shape.broadcast %arg0, %arg1 : !shape.shape, !shape.shape -> tensor<?xindex>
143+
%result = shape.broadcast %arg0, %arg1
144+
: !shape.shape, !shape.shape -> tensor<?xindex>
144145
return %result : tensor<?xindex>
145146
}
146147

147148

148149
// -----
149150

150-
func @broadcast_error_possible(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
151+
func @broadcast(%arg0 : !shape.shape, %arg1 : tensor<?xindex>) -> tensor<?xindex> {
151152
// expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `shape` to propagate them}}
152-
%result = shape.broadcast %arg0, %arg1 : !shape.shape, tensor<?xindex> -> tensor<?xindex>
153+
%result = shape.broadcast %arg0, %arg1
154+
: !shape.shape, tensor<?xindex> -> tensor<?xindex>
153155
return %result : tensor<?xindex>
154156
}

0 commit comments

Comments
 (0)