-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Update TileOp infer shape #134732
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Jerry-Ge (Jerry-Ge) Changesupdate to use getConstShapeValues in TileOp's shape inference Change-Id: Ie40d127eec16f7edfa3121ed9dcd5e3134138754 Full diff: https://github.com/llvm/llvm-project/pull/134732.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c8e9ad8bd3346..92cb9875187bf 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1612,19 +1612,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- DenseIntElementsAttr multiplesAttr;
- if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
- return failure();
-
- SmallVector<int64_t> multiples = llvm::to_vector(
- llvm::map_range(multiplesAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+ SmallVector<int64_t> multiples;
+ if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
+ multiples)) {
+ auto rank =
+ cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
+ SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+ inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+ return success();
+ } else {
+ multiples = convertToMlirShape(multiples);
+ }
ShapeAdaptor inputShape(adaptor.getInput1().getType());
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {
outputShape.resize(multiples.size(), ShapedType::kDynamic);
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outputShape, inputType));
return success();
} else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
return failure();
@@ -1632,13 +1638,17 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
// Any non dynamic dimension can be multiplied to a known size.
outputShape.reserve(multiples.size());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
- int64_t dim = inputShape.getDimSize(i);
- if (dim != ShapedType::kDynamic)
- dim *= multiples[i];
- outputShape.push_back(dim);
+ if (multiples[i] == ShapedType::kDynamic) {
+ outputShape.push_back(ShapedType::kDynamic);
+ } else {
+ int64_t dim = inputShape.getDimSize(i);
+ if (dim != ShapedType::kDynamic)
+ dim *= multiples[i];
+ outputShape.push_back(dim);
+ }
}
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 761e489bdeae5..19d5bd38535de 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_tile_unknown_multiples
+func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
+ // CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
+ %cst = tosa.const_shape {value = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_transpose_static
func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
// CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>
@@ -1506,3 +1517,21 @@ func.func @test_mul_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// CHECK-LABEL: @test_tiled
+// off_value is tiled into [N, 1, 1] where N = product(arg0.shape[:])
+func.func @test_tiled(%arg0: tensor<1x2x3xf32>) -> tensor<?x1x1xf32> {
+ // CHECK-DAG: %[[CST:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+ // CHECK-DAG: %[[CONCAT:.*]] = tosa.concat_shape {{.*}} : (!tosa.shape<1>, !tosa.shape<2>) -> !tosa.shape<3>
+ // CHECK: %[[TILED:.*]] = tosa.tile %[[CST]], %[[CONCAT]] : (tensor<1x1x1xf32>, !tosa.shape<3>) -> tensor<6x1x1xf32>
+ %off_value = "tosa.const"() { value = dense<0.5> : tensor<1x1x1xf32> } : () -> tensor<1x1x1xf32>
+ %0 = tosa.dim %arg0 { axis = 0 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+ %1 = tosa.dim %arg0 { axis = 1 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+ %2 = tosa.dim %arg0 { axis = 2 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+ %3 = tosa.mul_shape %0, %1 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+ %4 = tosa.mul_shape %3, %2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+ %cst_shape_1_1 = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %5 = tosa.concat_shape %4, %cst_shape_1_1 : (!tosa.shape<1>, !tosa.shape<2>) -> !tosa.shape<3>
+ %tiled = tosa.tile %off_value, %5: (tensor<1x1x1xf32>, !tosa.shape<3>) -> tensor<?x1x1xf32>
+ return %tiled : tensor<?x1x1xf32>
+}
|
@llvm/pr-subscribers-mlir-tosa Author: Jerry-Ge (Jerry-Ge) Changesupdate to use getConstShapeValues in TileOp's shape inference Change-Id: Ie40d127eec16f7edfa3121ed9dcd5e3134138754 Full diff: https://github.com/llvm/llvm-project/pull/134732.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c8e9ad8bd3346..92cb9875187bf 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1612,19 +1612,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
TileOp::Adaptor adaptor,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
- DenseIntElementsAttr multiplesAttr;
- if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
- return failure();
-
- SmallVector<int64_t> multiples = llvm::to_vector(
- llvm::map_range(multiplesAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+ SmallVector<int64_t> multiples;
+ if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
+ multiples)) {
+ auto rank =
+ cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
+ SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+ inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+ return success();
+ } else {
+ multiples = convertToMlirShape(multiples);
+ }
ShapeAdaptor inputShape(adaptor.getInput1().getType());
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {
outputShape.resize(multiples.size(), ShapedType::kDynamic);
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(
+ ShapedTypeComponents(outputShape, inputType));
return success();
} else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
return failure();
@@ -1632,13 +1638,17 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
// Any non dynamic dimension can be multiplied to a known size.
outputShape.reserve(multiples.size());
for (int i = 0, s = inputShape.getRank(); i < s; i++) {
- int64_t dim = inputShape.getDimSize(i);
- if (dim != ShapedType::kDynamic)
- dim *= multiples[i];
- outputShape.push_back(dim);
+ if (multiples[i] == ShapedType::kDynamic) {
+ outputShape.push_back(ShapedType::kDynamic);
+ } else {
+ int64_t dim = inputShape.getDimSize(i);
+ if (dim != ShapedType::kDynamic)
+ dim *= multiples[i];
+ outputShape.push_back(dim);
+ }
}
- inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 761e489bdeae5..19d5bd38535de 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
// -----
+// CHECK-LABEL: @test_tile_unknown_multiples
+func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
+ // CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
+ %cst = tosa.const_shape {value = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+ %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_transpose_static
func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
// CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>
@@ -1506,3 +1517,21 @@ func.func @test_mul_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// CHECK-LABEL: @test_tiled
+// off_value is tiled into [N, 1, 1] where N = product(arg0.shape[:])
+func.func @test_tiled(%arg0: tensor<1x2x3xf32>) -> tensor<?x1x1xf32> {
+ // CHECK-DAG: %[[CST:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+ // CHECK-DAG: %[[CONCAT:.*]] = tosa.concat_shape {{.*}} : (!tosa.shape<1>, !tosa.shape<2>) -> !tosa.shape<3>
+ // CHECK: %[[TILED:.*]] = tosa.tile %[[CST]], %[[CONCAT]] : (tensor<1x1x1xf32>, !tosa.shape<3>) -> tensor<6x1x1xf32>
+ %off_value = "tosa.const"() { value = dense<0.5> : tensor<1x1x1xf32> } : () -> tensor<1x1x1xf32>
+ %0 = tosa.dim %arg0 { axis = 0 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+ %1 = tosa.dim %arg0 { axis = 1 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+ %2 = tosa.dim %arg0 { axis = 2 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+ %3 = tosa.mul_shape %0, %1 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+ %4 = tosa.mul_shape %3, %2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+ %cst_shape_1_1 = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2>
+ %5 = tosa.concat_shape %4, %cst_shape_1_1 : (!tosa.shape<1>, !tosa.shape<2>) -> !tosa.shape<3>
+ %tiled = tosa.tile %off_value, %5: (tensor<1x1x1xf32>, !tosa.shape<3>) -> tensor<?x1x1xf32>
+ return %tiled : tensor<?x1x1xf32>
+}
|
1beeaef
to
bf79892
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
update to use getConstShapeValues in TileOp's shape inference Signed-off-by: Tai Ly <[email protected]> Change-Id: Ie40d127eec16f7edfa3121ed9dcd5e3134138754
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, though might be worth adding a test case for this path: https://github.com/llvm/llvm-project/pull/134732/files#diff-90956ba24a2a97cc56a9a3659c7e46e56f1bd791a869246c6a758f9c93f1434fR1617
Right now. Quote from @Tai78641 : "There's no way to construct a failure case because there is a trait to enforce that shape input must have a shape operand. but only shape operand left is const_shape " |
Makes sense, thanks! In that case I think my preference would be to return "failure()", to avoid maintaining dead code |
update to use getConstShapeValues in TileOp's shape inference Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Tai Ly <[email protected]>
update to use getConstShapeValues in TileOp's shape inference Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Tai Ly <[email protected]>
update to use getConstShapeValues in TileOp's shape inference