Skip to content

Commit bd077e9

Browse files
rafaelubalmw
authored and jpienaar committed
Released restriction that prevented implicit dynamic-to-static dimension type cast in TOSA ops.
Reviewed By: jpienaar, gflegar

Differential Revision: https://reviews.llvm.org/D156714
1 parent ea90e28 commit bd077e9

File tree

4 files changed

+43
-27
lines changed

4 files changed

+43
-27
lines changed

mlir/docs/Traits/Broadcastable.md

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,17 @@ Given the shapes of two ranked input operands, the result's shape is inferred by
6969
```python
7070
InferShape(shape0, shape1):
7171

72-
# Equalize ranks
73-
rank = max(GetRank(shape0), GetRank(shape1))
74-
ExpandRank(shape0, rank)
75-
ExpandRank(shape1, rank)
72+
# Equalize ranks
73+
rank = max(GetRank(shape0), GetRank(shape1))
74+
ExpandRank(shape0, rank)
75+
ExpandRank(shape1, rank)
7676

77-
# Infer shape
78-
inferredShape = []
79-
for (dim0, dim1) in zip(shape0, shape1):
80-
inferredDim = InferDim(dim0, dim1)
81-
inferredShape.append(inferredDim)
82-
return inferredShape
77+
# Infer shape
78+
inferredShape = []
79+
for (dim0, dim1) in zip(shape0, shape1):
80+
inferredDim = InferDim(dim0, dim1)
81+
inferredShape.append(inferredDim)
82+
return inferredShape
8383
```
8484

8585
The result shape for an operation with an arbitrary number of input operands is then inferred by discarding unranked operands, applying shape inference on the first ranked operand pair, and updating the inferred shape with each additional ranked operand. If the operation has no ranked operands, the result shape cannot be inferred. If the operation has exactly one ranked operand, its shape is directly provided as the inferred result shape. Formally:
@@ -111,7 +111,7 @@ Once a rank match is guaranteed, each dimension of the inferred shape is compare
111111
| `inferredDim` | `actualDim` | Verification outcome |
112112
| ------------- | ----------- | -------------------- |
113113
| ? | ? | **OK** |
114-
| ? | static | **Error** <br> An inferred dimension being dynamic indicates that its size cannot be inferred at compile time from its input operands. The presence of a static dimension in the actual result is counterintuitive and is therefore not allowed. |
114+
| ? | static | **OK** <br> A failure to guarantee that the runtime dimension size of the result is equal to `actualDim` causes undefined behavior. While unusual, this implicit dynamic-to-static cast is convenient in certain scenarios, such as an intermediate state of a shape inference pass. Ultimately, a static dimension in the result implies that all input dimension sizes are also known at compile time and may therefore become static as well, preferably. |
115115
| static | ? | **OK** <br> The actual result dimension may be dynamic even when a static size can be inferred at compile time. The programmer may choose to relax the specificity of the result dimension for forward compatibility of the result type. |
116116
| static | static | **OK if equal** <br> When both the inferred and actual dimensions are static, they must be set to the same size. |
117117

@@ -134,7 +134,6 @@ Verify(op):
134134

135135
# Verify
136136
for (inferredDim, actualDim) in zip(inferredShape, actualShape):
137-
ERROR_IF(IsDynamic(inferredDim) and IsStatic(actualDim))
138137
ERROR_IF(IsStatic(actualDim) and inferredDim != actualDim)
139138
```
140139
@@ -195,3 +194,5 @@ The following are incorrect uses of broadcastable ops:
195194
// tensor<4xi32>. Broadcast semantics are not applicable for results.
196195
%result = "test.broadcastable"(%arg0, %arg1) : (tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
197196
```
197+
198+

mlir/lib/Dialect/Traits.cpp

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -195,17 +195,10 @@ static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
195195

196196
static bool isCompatibleInferredReturnShape(ArrayRef<int64_t> inferred,
197197
ArrayRef<int64_t> existing) {
198+
// If both inferred and existing dimensions are static, they must be equal.
198199
auto isCompatible = [](int64_t inferredDim, int64_t existingDim) {
199-
// The following criterion is used to determine the validity of an existing
200-
// dimension:
201-
//
202-
// inferredDim existingDim Behavior
203-
// ----------- ----------- --------
204-
// dynamic dynamic OK
205-
// dynamic static Error
206-
// static dynamic OK
207-
// static static OK if equal
208-
return ShapedType::isDynamic(existingDim) || inferredDim == existingDim;
200+
return ShapedType::isDynamic(existingDim) ||
201+
ShapedType::isDynamic(inferredDim) || inferredDim == existingDim;
209202
};
210203
if (inferred.size() != existing.size())
211204
return false;

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,45 @@ func.func @test_abs_scalar(%arg0: tensor<f32>) -> tensor<f32> {
2020
// -----
2121

2222
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
23-
// CHECK-LABEL: @test_abs_1d_cast_result
23+
// CHECK-LABEL: @test_abs_1d_cast_static_to_dynamic
2424
// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]
25-
func.func @test_abs_1d_cast_result(%arg0: tensor<5xf32>) -> tensor<?xf32> {
25+
func.func @test_abs_1d_cast_static_to_dynamic(%arg0: tensor<5xf32>) -> tensor<?xf32> {
2626
// CHECK: [[EMPTY:%.+]] = tensor.empty() : tensor<5xf32>
2727
// CHECK: [[RESULT:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[ARG0]] : tensor<5xf32>) outs([[EMPTY]] : tensor<5xf32>) {
2828
// CHECK: ^bb0([[IN0:%.+]]: f32, [[OUT0:%.+]]: f32):
2929
// CHECK: [[ABS:%.+]] = math.absf [[IN0]] : f32
3030
// CHECK: linalg.yield [[ABS]] : f32
3131
// CHECK: } -> tensor<5xf32>
32+
// CHECK: [[CAST_RESULT:%.+]] = tensor.cast [[RESULT]] : tensor<5xf32> to tensor<?xf32>
3233
%0 = "tosa.abs"(%arg0) : (tensor<5xf32>) -> tensor<?xf32>
3334

34-
// CHECK: [[CAST_RESULT:%.+]] = tensor.cast [[RESULT]] : tensor<5xf32> to tensor<?xf32>
3535
// CHECK: return [[CAST_RESULT]] : tensor<?xf32>
3636
return %0 : tensor<?xf32>
3737
}
3838

3939
// -----
4040

41+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
42+
// CHECK-LABEL: @test_abs_1d_cast_dynamic_to_static
43+
// CHECK-SAME: (%[[ARG0:[0-9a-zA-Z_]*]]
44+
func.func @test_abs_1d_cast_dynamic_to_static(%arg0: tensor<?xf32>) -> tensor<5xf32> {
45+
// CHECK: %[[ZERO:.*]] = arith.constant 0 : index
46+
// CHECK: %[[DIM_SIZE:.*]] = tensor.dim %[[ARG0]], %[[ZERO]] : tensor<?xf32>
47+
// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM_SIZE]]) : tensor<?xf32>
48+
// CHECK: %[[RESULT:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins(%[[ARG0]] : tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
49+
// CHECK: ^bb0(%[[VAL_0:.*]]: f32, %[[VAL_1:.*]]: f32):
50+
// CHECK: %[[VAL_2:.*]] = math.absf %[[VAL_0]] : f32
51+
// CHECK: linalg.yield %[[VAL_2]] : f32
52+
// CHECK: } -> tensor<?xf32>
53+
// CHECK: %[[CAST_RESULT:.*]] = tensor.cast %[[RESULT]] : tensor<?xf32> to tensor<5xf32>
54+
%0 = "tosa.abs"(%arg0) : (tensor<?xf32>) -> tensor<5xf32>
55+
56+
// CHECK: return %[[CAST_RESULT]] : tensor<5xf32>
57+
return %0 : tensor<5xf32>
58+
}
59+
60+
// -----
61+
4162
// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
4263
// CHECK-LABEL: @test_abs_1d_dynamic
4364
// CHECK-SAME: ([[ARG0:%[0-9a-zA-Z_]*]]

mlir/test/Dialect/traits.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,10 @@ func.func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<?xi32>) -> t
111111

112112
// -----
113113

114-
// Error for inferred dynamic dimension but existing static dimensions
114+
// It is alright to have an implicit dynamic-to-static cast in a dimension size
115+
// as long as the runtime result size is consistent with the result tensor's
116+
// static dimension.
115117
func.func @broadcast_tensor_tensor_tensor(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<2xi32> {
116-
// expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands's shapes '?'}}
117118
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
118119
return %0 : tensor<2xi32>
119120
}

0 commit comments

Comments
 (0)