Skip to content

[mlir][arith] Disallow casting tensor dimensions #93349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 29, 2024

Conversation

kuhar
Copy link
Member

@kuhar kuhar commented May 24, 2024

Tighten the verifier for arith cast ops to disallow changing tensor dimensions, e.g., static to dynamic. After this change:

  • arith.cast_op %x : tensor<4xi32> to tensor<4xf32> remains valid
  • arith.cast_op %x : tensor<4xi32> to tensor<?xf32> becomes invalid
  • arith.cast_op %x : tensor<?xi32> to tensor<4xf32> becomes invalid

This is mostly to simplify the op semantics. See the discussion thread for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.

@llvmbot
Copy link
Member

llvmbot commented May 24, 2024

@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

Tighten the verifier for arith cast ops to disallow changing tensor dimensions, e.g., static to dynamic. After this change:

  • arith.cast_op %x : tensor&lt;4xi32&gt; to tensor&lt;4xf32&gt; remains valid
  • arith.cast_op %x : tensor&lt;4xi32&gt; to tensor&lt;?xf32&gt; becomes invalid
  • arith.cast_op %x : tensor&lt;?xi32&gt; to tensor&lt;4xf32&gt; becomes invalid

This is mostly to simplify the op semantics. See the discussion thread for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.


Full diff: https://github.com/llvm/llvm-project/pull/93349.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Arith/IR/ArithOps.td (+15-2)
  • (modified) mlir/test/Dialect/Arith/canonicalize.mlir (-8)
  • (modified) mlir/test/Dialect/Arith/invalid.mlir (+25-1)
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4e4c6fd601777..bdf264aec1d5d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -83,12 +83,25 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
                           attr-dict `:` type($result) }];
 }
 
+// Checks that tensor input and outputs have identical shapes. This is stricker
+// than the verification done in `SameOperandsAndResultShape` that allows for
+// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
+// compatible with static ones).
+def SameInputOutputTensorDims : PredOpTrait<
+    "input and output have the same tensor dimensions",
+    AllMatchSameOperatorPred<["in", "out"],
+      "(::llvm::isa<::mlir::TensorType>($_self.getType()) ?"
+      " ::llvm::cast<::mlir::TensorType>($_self.getType()).getShape() :"
+      " ::llvm::ArrayRef<int64_t>{})">>;
+
 // Base class for arithmetic cast operations. Requires a single operand and
-// result. If either is a shaped type, then the other must be of the same shape.
+// result. If either is a shaped type, then the other must be of the same
+// shape.  In the case of tensor types, this also includes the corresponding
+// operand/result dimensions being equal.
 class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
                    list<Trait> traits = []> :
     Arith_Op<mnemonic, traits # [Pure, SameOperandsAndResultShape,
-      DeclareOpInterfaceMethods<CastOpInterface>]>,
+      SameInputOutputTensorDims, DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins From:$in)>,
     Results<(outs To:$out)> {
   let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1a387c20c4b29..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2950,14 +2950,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
-// Just checks that this doesn't crash.
-// CHECK-LABEL: @signedExtendSplatAsDynamicShape
-func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
-  %splat = arith.constant dense<5> : tensor<2xi16>
-  %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
-  return %extsplat : tensor<?xi64>
-}
-
 // CHECK-LABEL: @extsi_i0
 //       CHECK:   %[[ZERO:.*]] = arith.constant 0 : i16
 //       CHECK:   return %[[ZERO]] : i16
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index ada849220bb83..a3cfb6baa2e1d 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -1,13 +1,21 @@
 // RUN: mlir-opt -split-input-file %s -verify-diagnostics
 
 func.func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
-  // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
+  // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
   %0 = arith.index_cast %arg0 : tensor<index> to tensor<2xi64>
   return %0 : tensor<2xi64>
 }
 
 // -----
 
+func.func @test_index_cast_shape_dim_error(%arg0 : tensor<2xindex>) -> tensor<?xi64> {
+  // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<?xi64>
+  return %0 : tensor<?xi64>
+}
+
+// -----
+
 func.func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
   // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
   %0 = arith.index_cast %arg0 : tensor<index> to i64
@@ -655,6 +663,14 @@ func.func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {
 
 // -----
 
+func.func @extsi_tesor_dim(%arg0 : tensor<4xi32>) {
+  // expected-error@+1 {{'arith.extsi' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
+  return
+}
+
+// -----
+
 func.func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
   // expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
   %0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
@@ -703,6 +719,14 @@ func.func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {
 
 // -----
 
+func.func @bitcast_tensor_dim(%arg0 : tensor<4xf32>) {
+  // expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.bitcast %arg0 : tensor<4xf32> to tensor<?xi32>
+  return
+}
+
+// -----
+
 func.func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
   // expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
   %0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>

kuhar added 3 commits May 24, 2024 17:45
Tighten the verifier for arith cast ops to disallow changing tensor
dimensions, e.g., static to dynamic. After this change:
* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid

This is mostly to simplify the op semantics. See the discussion thread
for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.
@kuhar kuhar force-pushed the cast-tensor-dims branch from 7cb3da5 to 1f5b0da Compare May 24, 2024 21:45
@rengolin
Copy link
Member

This will help make things more sane, at least until we stop supporting tensors in arith / math.

@kuhar kuhar merged commit 5bfe4b9 into llvm:main May 29, 2024
7 checks passed
vg0204 pushed a commit to vg0204/llvm-project that referenced this pull request May 29, 2024
Tighten the verifier for arith cast ops to disallow changing tensor
dimensions, e.g., static to dynamic. After this change:
* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid

This is mostly to simplify the op semantics. See the discussion thread
for more context:
https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Jun 26, 2024
…d66d572fb

Local branch amd-gfx 6b1d66d Merged main:516a9f5183446d695c701fcdc562d543c9ccb297 into amd-gfx:295897600144
Remote branch main 5bfe4b9 [mlir][arith] Disallow casting tensor dimensions (llvm#93349)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants