[mlir] Bug fixes in TOSA shape inference pass #104146

Merged
merged 3 commits (Aug 16, 2024)

34 changes: 26 additions & 8 deletions mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -88,18 +88,36 @@ class TypeModificationState {
     // For each use whose type changed, cast the value with the new type back to
     // the old type.
     for (auto [value, oldType] : oldTypes) {
-      tensor::CastOp castedValue;
-      for (auto &use : value.getUses()) {
-        if (canBeRefined(use.getOwner()))
+      // The call to 'use->set()' in the body of the loop below invalidates the
+      // iterator used to traverse op uses, so it is important to make a copy of
+      // these first.
+      llvm::SmallVector<OpOperand *> uses = llvm::map_to_vector(
+          value.getUses(),
+          [](OpOperand &use) -> OpOperand * {
+            return &use;
+          });
+
+      // A 'tensor.cast' op is emitted only if needed. Once emitted, it is
+      // cached and reused by all consumers.
+      tensor::CastOp castValue;
+
+      // Traverse all uses
+      for (OpOperand *use : uses) {
+        if (canBeRefined(use->getOwner()))
           continue;
 
-        // Cache the cast to avoid generating duplicates
-        if (!castedValue) {
-          ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
-          castedValue = builder.create<tensor::CastOp>(oldType, value);
+        if (!castValue) {
+          // Set the insertion point as far back as possible, since new
+          // consumers of the 'tensor.cast' op generated in future iterations
+          // are likely to be further up in the code due to the order in which
+          // they appear in the use list.
+          OpBuilder builder{value.getContext()};
+          builder.setInsertionPointAfter(value.getDefiningOp());
+          castValue =
+              builder.create<tensor::CastOp>(value.getLoc(), oldType, value);
         }
 
-        use.set(castedValue);
+        use->set(castValue);
       }
     }

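The core of the fix above is to snapshot a value's use list before rewriting any operand, because OpOperand::set() re-links the operand into the use list of the new value and invalidates iterators over Value::getUses(). The following is a minimal sketch of that pattern, not part of the patch; the helper name replaceUsesSafely and the value names are illustrative assumptions.

#include "mlir/IR/Value.h"
#include "llvm/ADT/SmallVector.h"

// Rewrite every use of 'value' to 'newValue' without iterating the live
// use list while it is being modified.
static void replaceUsesSafely(mlir::Value value, mlir::Value newValue) {
  // Copy the uses first: OpOperand::set() moves each operand onto
  // newValue's use list, which invalidates iterators over value.getUses().
  llvm::SmallVector<mlir::OpOperand *> uses;
  for (mlir::OpOperand &use : value.getUses())
    uses.push_back(&use);

  // Now each operand can be rewritten safely.
  for (mlir::OpOperand *use : uses)
    use->set(newValue);
}

The patch expresses the same snapshot in one step with llvm::map_to_vector (from llvm/ADT/SmallVectorExtras.h) and only rewrites the uses that cannot be shape-refined.
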
56 changes: 55 additions & 1 deletion mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1373,4 +1373,58 @@ func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<1
   // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
   %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
   return %1 : tensor<?x16x16x16xf32>
-}
+}
+
+// -----
+
+// This test locks in two bug fixes exercised by the code below.
+//
+// 1. Context
+//
+// When shape propagation hits an operation that does not support shape
+// inference (here 'tensor.expand_shape'), it must revert the currently
+// inferred shape of its consumers back to the originally expected input
+// type to avoid potential op verification errors. This type reversal is
+// done through an additional 'tensor.cast' op.
+//
+//
+// 2. Preserving the list of non-inferrable consumers
+//
+// When multiple non-inferrable consumers of a shape-inferred value are found
+// (here, the 2 occurrences of 'tensor.expand_shape' consuming the output of
+// 'tosa.cast'), their input argument ('%0') must be altered to consume the
+// output of the new 'tensor.cast' op. While these replacements occur, the use
+// list of the producer ('tosa.cast') is also implicitly altered, invalidating
+// any iterators associated with it. It is therefore necessary to create a copy
+// of this use list ahead of time. Before this bug fix, the second
+// 'tensor.expand_shape' op below was not updated correctly.
+//
+// 3. Guaranteeing def-use order
+//
+// When emitting the 'tensor.cast' op, it is important to guarantee that its
+// output value is defined before all of its consumers (here, both of the
+// 'tensor.expand_shape' ops). In a previous version of the code, this insertion
+// occurred right before the first encountered consumer. Since use lists are
+// saved in reverse order, the 'tensor.cast' op was inserted before the second
+// 'tensor.expand_shape' op, leading to a def-use order violation when the
+// first 'tensor.expand_shape' op was later updated. The current implementation
+// sets the insertion point right after the producer of the last shape-inferred
+// value (here 'tosa.cast'), which guarantees correct def-use order for all
+// future operand updates.
+
+// CHECK-LABEL: test_multiple_non_inferrable_consumers
+// CHECK-SAME: %[[ARG:.*]]: tensor<1x2x8xf32>
+func.func @test_multiple_non_inferrable_consumers(%arg0: tensor<1x2x8xf32>) {
+  // CHECK: %[[TOSA_CAST:.*]] = tosa.cast %[[ARG]] : (tensor<1x2x8xf32>) -> tensor<1x2x8xf32>
+  // CHECK: %[[TENSOR_CAST:.*]] = tensor.cast %[[TOSA_CAST]] : tensor<1x2x8xf32> to tensor<?x2x8xf32>
+  %0 = tosa.cast %arg0 : (tensor<1x2x8xf32>) -> tensor<?x2x8xf32>
+
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %0, %c0 : tensor<?x2x8xf32>
+
+  // CHECK: tensor.expand_shape %[[TENSOR_CAST]]
+  // CHECK: tensor.expand_shape %[[TENSOR_CAST]]
+  %expanded_0 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
+  %expanded_1 = tensor.expand_shape %0 [[0], [1, 2], [3]] output_shape [%dim, 1, 4, 8] : tensor<?x2x8xf32> into tensor<?x1x2x8xf32>
+  return
+}
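To illustrate point 3 of the comment above, here is a minimal sketch of the insertion-point choice; it is not part of the patch, the helper name castBackToOldType is illustrative, and it assumes the value is an op result rather than a block argument.

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

// Emit a tensor.cast from 'value' back to 'oldType', anchored right after
// the producer so it dominates every present and future consumer.
static mlir::Value castBackToOldType(mlir::OpBuilder &builder,
                                     mlir::Value value, mlir::Type oldType) {
  // Inserting after the defining op (rather than before whichever consumer
  // happens to be visited first) keeps the def-use order valid no matter in
  // which order the use list is traversed.
  builder.setInsertionPointAfter(value.getDefiningOp());
  return builder.create<mlir::tensor::CastOp>(value.getLoc(), oldType, value);
}

With the cast anchored at its producer ('tosa.cast' in this test), the two 'tensor.expand_shape' consumers can be updated in any order.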