Skip to content

Commit ef6e7af

Browse files
authored
[mlir] [tosa] Bug fixes in shape inference pass (llvm#104146)
This change addresses 2 bugs in the TOSA shape inference pass (`--tosa-infer-shapes`). The included unit test contains a detailed description of the issues. - Input IR ``` func.func @main(%arg0: tensor<1x2x8xf32>) { %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32> %c0 = arith.constant 0 : index %dim = tensor.dim %0, %c0 : tensor<?x2x8xf32> %expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32> %expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32> return } ``` - Output IR ``` module { func.func @main(%arg0: tensor<1x2x8xf32>) { %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32> // This cast was previously inserted between both 'tensor.expand_shape' ops. %cast = tensor.cast %0 : tensor<1x2x8xf32> to tensor<?x2x8xf32> %c0 = arith.constant 0 : index %dim = tensor.dim %0, %c0 : tensor<1x2x8xf32> // The operand of the first 'tensor.expand_shape' op was not previously updated // from '%0' to '%cast' due to an invalidation of the iterator traversing the // use list of the 'tosa.cast' op. %expanded_0 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32> %expanded_1 = tensor.expand_shape %cast [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32> return } ```
1 parent 99696b3 commit ef6e7af

File tree

2 files changed

+81
-9
lines changed

2 files changed

+81
-9
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,36 @@ class TypeModificationState {
8888
// For each use whose type changed, cast the value with the new type back to
8989
// the old type.
9090
for (auto [value, oldType] : oldTypes) {
91-
tensor::CastOp castedValue;
92-
for (auto &use : value.getUses()) {
93-
if (canBeRefined(use.getOwner()))
91+
// The call to 'use->set()' in the body of the loop below invalidates the
92+
// iterator used to traverse op uses, so it is important to make a copy of
93+
// these first.
94+
llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
95+
value.getUses(),
96+
[](OpOperand &use) -> OpOperand * {
97+
return &use;
98+
});
99+
100+
// A 'tensor.cast' op is emitted only if needed. Once emitted, it is
101+
// cached and reused by all consumers.
102+
tensor::CastOp castValue;
103+
104+
// Traverse all uses
105+
for (OpOperand *use : uses) {
106+
if (canBeRefined(use->getOwner()))
94107
continue;
95108

96-
// Cache the cast to avoid generating duplicates
97-
if (!castedValue) {
98-
ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
99-
castedValue = builder.create<tensor::CastOp>(oldType, value);
109+
if (!castValue) {
110+
// Set the insertion point as far back as possible, since new
111+
// consumers of the 'tensor.cast' op generated in future iterations
112+
// are likely to be further up in the code due to the order in which
113+
// they appear in the use list.
114+
OpBuilder builder{value.getContext()};
115+
builder.setInsertionPointAfter(value.getDefiningOp());
116+
castValue =
117+
builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
100118
}
101119

102-
use.set(castedValue);
120+
use->set(castValue);
103121
}
104122
}
105123

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1373,4 +1373,58 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
13731373
// CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
13741374
%1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
13751375
return %1 : tensor<?x16x16x16xf32>
1376-
}
1376+
}
1377+
1378+
// -----
1379+
1380+
// This test locks in the fixes for two bugs that manifest in the code below.
1381+
//
1382+
// 1. Context
1383+
//
1384+
// When shape propagation hits an operation that does not support shape
1385+
// inference (here 'tensor.expand_shape'), it must revert the currently
1386+
// inferred shape of its consumers back to the originally expected input
1387+
// type to avoid potential op verification errors. This type reversal is
1388+
// done through an additional 'tensor.cast' op.
1389+
//
1390+
//
1391+
// 2. Preserving list of non-inferrable consumers
1392+
//
1393+
// When multiple non-inferrable consumers of a shape-inferred value are found
1394+
// (here, the 2 occurrences of 'tensor.expand_shape' consuming the output of
1395+
// 'tosa.cast'), their input argument ('%0') must be altered to consume the
1396+
// output of the new 'tensor.cast' op. While these replacements occur, the use list
1397+
// of the producer ('tosa.cast') is also implicitly altered, invalidating any
1398+
// iterators associated with it. It is therefore necessary to create a copy of
1399+
// this use list ahead of time. Before this bug fix, the second
1400+
// 'tensor.expand_shape' op below was not updated correctly.
1401+
//
1402+
// 3. Guaranteeing def-use order
1403+
//
1404+
// When emitting the 'tensor.cast' op, it is important to guarantee that its
1405+
// output value is defined before all of its consumers (here, both of the
1406+
// 'tensor.expand_shape' ops). In a previous version of the code, this insertion
1407+
// occurred right before the first encountered consumer. Since use lists are
1408+
// saved in reverse order, the 'tensor.cast' op was inserted before the second
1409+
// 'tensor.expand_shape' op, leading to a def-use order violation when the
1410+
// first 'tensor.expand_shape' op was later updated. The current implementation
1411+
// sets the insertion point right after the producer of the last shape-inferred
1412+
// value (here 'tosa.cast'), which guarantees correct def-use order for all
1413+
// future operand updates.
1414+
1415+
// CHECK-LABEL: test_multiple_non_inferrable_consumers
1416+
// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x8xf32>
1417+
func.func @test_multiple_non_inferrable_consumers(%arg0: tensor<1x2x8xf32>) {
1418+
// CHECK: %[[TOSA_CAST:.*]] = tosa.cast %[[ARG]] : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
1419+
// CHECK: %[[TENSOR_CAST:.*]] = tensor.cast %[[TOSA_CAST]] : tensor<1x2x8xf32> to tensor<?x2x8xf32>
1420+
%0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>
1421+
1422+
%c0 = arith.constant 0 : index
1423+
%dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>
1424+
1425+
// CHECK: tensor.expand_shape %[[TENSOR_CAST]]
1426+
// CHECK: tensor.expand_shape %[[TENSOR_CAST]]
1427+
%expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
1428+
%expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
1429+
return
1430+
}

0 commit comments

Comments
 (0)