Skip to content

Commit 98aad40

Browse files
sabaumaeric-k256
authored andcommitted
[tosa] Improve inferred shapes of TOSA operations
The TosaInferShapes pass avoids updating the shapes of tensor operators when the consumers are not TOSA operations, limiting the efficacy of TosaInferShapes when the IR is a mix of TOSA and other operations. This change attempts to update the result shapes when the consumers themselves have reasonable type/shape inference methods. Reviewed By: eric-k256 Differential Revision: https://reviews.llvm.org/D151228
1 parent 0aa7ea4 commit 98aad40

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/BuiltinOps.h"
2222
#include "mlir/IR/IRMapping.h"
2323
#include "mlir/IR/Matchers.h"
24+
#include "mlir/Interfaces/InferTypeOpInterface.h"
2425
#include "mlir/Pass/Pass.h"
2526
#include "mlir/Transforms/DialectConversion.h"
2627
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -201,6 +202,16 @@ void propagateShapesInRegion(Region &region) {
201202
return it->second;
202203
};
203204

205+
// Check whether this use case is replaceable. We define an op as
206+
// being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
207+
// type-inference related interface.
208+
auto isReplaceableUser = [](Operation *user) -> bool {
209+
return isa<func::ReturnOp>(user) ||
210+
user->getDialect()->getNamespace() ==
211+
TosaDialect::getDialectNamespace() ||
212+
isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
213+
};
214+
204215
for (auto &block : region) {
205216
for (Operation &op : block) {
206217
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
@@ -227,18 +238,8 @@ void propagateShapesInRegion(Region &region) {
227238
Value result = std::get<0>(it);
228239
ShapedTypeComponents predictedShape = std::get<1>(it);
229240

230-
// Check whether this use case is replaceable. We define an op as
231-
// being replaceable if it is used by a ReturnOp or a TosaOp.
232-
bool replaceable = true;
233-
for (auto *user : result.getUsers()) {
234-
if (isa<func::ReturnOp>(user))
235-
continue;
236-
if (user->getDialect()->getNamespace() ==
237-
TosaDialect::getDialectNamespace())
238-
continue;
239-
240-
replaceable = false;
241-
}
241+
if (!llvm::all_of(result.getUsers(), isReplaceableUser))
242+
continue;
242243

243244
// Determine the knowledge based on the output type.
244245
// TODO: should also query WIP type probably
@@ -256,9 +257,6 @@ void propagateShapesInRegion(Region &region) {
256257
}
257258
}
258259

259-
if (!replaceable)
260-
continue;
261-
262260
// Compute the new type based on the joined version.
263261
auto newKnowledge =
264262
ValueKnowledge::join(currentKnowledge, inferredKnowledge);

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,3 +1237,33 @@ func.func @test_unranked_equal(%arg0 : tensor<*xf32>, %arg1 : tensor<f32>) -> ()
12371237

12381238
return
12391239
}
1240+
1241+
// -----
1242+
1243+
// CHECK-LABEL: test_non_tosa_consumer_shape
1244+
func.func @test_non_tosa_consumer_shape(%arg0: tensor<4x4xf32>) -> !shape.shape {
1245+
// CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
1246+
%0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
1247+
%1 = shape.shape_of %0 : tensor<*xf32> -> !shape.shape
1248+
return %1 : !shape.shape
1249+
}
1250+
1251+
// -----
1252+
1253+
// CHECK-LABEL: test_non_tosa_consumer_shape2
1254+
func.func @test_non_tosa_consumer_shape2(%arg0: tensor<4x4xf32>) -> tensor<?xindex> {
1255+
// CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
1256+
%0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<*xf32>
1257+
%1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1258+
return %1 : tensor<?xindex>
1259+
}
1260+
1261+
// -----
1262+
1263+
// CHECK-LABEL: test_non_tosa_consumer_extract
1264+
func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index) -> f32 {
1265+
// CHECK: "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xf32>
1266+
%0 = "tosa.log"(%arg0) : (tensor<4x4xf32>) -> tensor<?x?xf32>
1267+
%1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
1268+
return %1 : f32
1269+
}

0 commit comments

Comments
 (0)