[mlir][tosa] Update TileOp infer shape #134732

Merged: 1 commit into llvm:main from tile_op on Apr 9, 2025

Conversation

Jerry-Ge (Member) commented Apr 7, 2025

update to use getConstShapeValues in TileOp's shape inference
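
In effect (mirroring the new test in the diff below), constant multiples read via getConstShapeValues may contain -1 entries, which convertToMlirShape turns into dynamic dimensions. A minimal illustration of the inferred result, taken from the test added in this PR:

```mlir
// A -1 entry in the constant multiples becomes a dynamic ('?') output dim;
// a static input dim times a static multiple stays static (2 * 2 = 4); a
// dynamic input dim stays dynamic regardless of the multiple.
%cst = tosa.const_shape {value = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
%0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
```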

llvmbot (Member) commented Apr 7, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-tosa

Author: Jerry-Ge (Jerry-Ge)

Changes

update to use getConstShapeValues in TileOp's shape inference

Change-Id: Ie40d127eec16f7edfa3121ed9dcd5e3134138754


Full diff: https://github.com/llvm/llvm-project/pull/134732.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+23-13)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+29)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c8e9ad8bd3346..92cb9875187bf 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1612,19 +1612,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  DenseIntElementsAttr multiplesAttr;
-  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
-    return failure();
-
-  SmallVector<int64_t> multiples = llvm::to_vector(
-      llvm::map_range(multiplesAttr.getValues<APInt>(),
-                      [](const APInt &val) { return val.getSExtValue(); }));
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+  SmallVector<int64_t> multiples;
+  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
+                               multiples)) {
+    auto rank =
+        cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  } else {
+    multiples = convertToMlirShape(multiples);
+  }
 
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
     return failure();
@@ -1632,13 +1638,17 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
-    int64_t dim = inputShape.getDimSize(i);
-    if (dim != ShapedType::kDynamic)
-      dim *= multiples[i];
-    outputShape.push_back(dim);
+    if (multiples[i] == ShapedType::kDynamic) {
+      outputShape.push_back(ShapedType::kDynamic);
+    } else {
+      int64_t dim = inputShape.getDimSize(i);
+      if (dim != ShapedType::kDynamic)
+        dim *= multiples[i];
+      outputShape.push_back(dim);
+    }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 761e489bdeae5..19d5bd38535de 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_tile_unknown_multiples
+func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
+  // CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
+  %cst = tosa.const_shape {value = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose_static
 func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
   // CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>
@@ -1506,3 +1517,21 @@ func.func @test_mul_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<*xf
   %0 = tosa.mul %arg0, %arg1, %shift : (tensor<f32>, tensor<f32>, tensor<1xi8>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// CHECK-LABEL: @test_tiled
+// off_value is tiled into [N, 1, 1] where N = product(arg0.shape[:])
+func.func @test_tiled(%arg0: tensor<1x2x3xf32>) -> tensor<?x1x1xf32> {
+    // CHECK-DAG: %[[CST:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32>
+    // CHECK-DAG: %[[CONCAT:.*]] = tosa.concat_shape {{.*}} : (!tosa.shape<1>, !tosa.shape<2>) -> !tosa.shape<3>
+    // CHECK: %[[TILED:.*]] = tosa.tile %[[CST]], %[[CONCAT]] : (tensor<1x1x1xf32>, !tosa.shape<3>) -> tensor<6x1x1xf32>
+    %off_value = "tosa.const"() { value = dense<0.5> : tensor<1x1x1xf32> } : () -> tensor<1x1x1xf32>
+    %0 = tosa.dim %arg0 { axis = 0 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+    %1 = tosa.dim %arg0 { axis = 1 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+    %2 = tosa.dim %arg0 { axis = 2 : i32 } : (tensor<1x2x3xf32>) -> !tosa.shape<1>
+    %3 = tosa.mul_shape %0, %1 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+    %4 = tosa.mul_shape %3, %2 : (!tosa.shape<1>, !tosa.shape<1>) -> !tosa.shape<1>
+    %cst_shape_1_1 = tosa.const_shape { value = dense<1> : tensor<2xindex> } : () -> !tosa.shape<2>
+    %5 = tosa.concat_shape %4, %cst_shape_1_1 : (!tosa.shape<1>, !tosa.shape<2>) -> !tosa.shape<3>
+    %tiled = tosa.tile %off_value, %5: (tensor<1x1x1xf32>, !tosa.shape<3>) -> tensor<?x1x1xf32>
+    return %tiled : tensor<?x1x1xf32>
+}

Jerry-Ge force-pushed the tile_op branch 2 times, most recently from 1beeaef to bf79892 on April 7, 2025 at 21:16
github-actions bot commented Apr 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

update to use getConstShapeValues in TileOp's shape inference

Signed-off-by: Tai Ly <[email protected]>
Change-Id: Ie40d127eec16f7edfa3121ed9dcd5e3134138754
lhutton1 (Contributor) left a comment:

LGTM, though might be worth adding a test case for this path: https://github.com/llvm/llvm-project/pull/134732/files#diff-90956ba24a2a97cc56a9a3659c7e46e56f1bd791a869246c6a758f9c93f1434fR1617

Jerry-Ge (Member, Author) commented Apr 8, 2025:

Right now, quoting @Tai78641: "There's no way to construct a failure case because there is a trait to enforce that the shape input must have a shape operand, but the only shape operand left is const_shape."
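
For context, a minimal sketch of the invariant being described, assuming upstream TOSA shape-operator semantics rather than anything shown verbatim in this PR: the multiples operand of tosa.tile has type !tosa.shape<N>, and tosa.const_shape is in practice the only producer of that type, so getConstShapeValues succeeds on any IR that passes verification:

```mlir
// Illustrative only: a constant shape is the sole way to feed tosa.tile's
// multiples operand, which makes the non-constant fallback unreachable.
%multiples = tosa.const_shape {value = dense<[2, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// Inferred result: 2 * 2 = 4, 3 * 2 = 6, and ? * 1 stays dynamic.
%out = tosa.tile %arg0, %multiples : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x6x?xi32>
```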

lhutton1 (Contributor) commented Apr 9, 2025

Makes sense, thanks! In that case I think my preference would be to return "failure()", to avoid maintaining dead code
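
A minimal sketch of that suggested alternative, reusing the getConstShapeValues and convertToMlirShape helpers from the diff above (a review suggestion, not the merged code):

```cpp
// Hypothetical variant per the suggestion above: since verified IR cannot
// produce a non-constant multiples operand, bail out with failure() instead
// of keeping an all-dynamic fallback that is effectively dead code.
SmallVector<int64_t> multiples;
if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
                               multiples))
  return failure();
multiples = convertToMlirShape(multiples);
```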

Jerry-Ge merged commit 751c3f5 into llvm:main on Apr 9, 2025
11 checks passed
AllinLeeYL pushed a commit to AllinLeeYL/llvm-project that referenced this pull request Apr 10, 2025
update to use getConstShapeValues in TileOp's shape inference

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Tai Ly <[email protected]>
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
update to use getConstShapeValues in TileOp's shape inference

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Tai Ly <[email protected]>