[mlir][arith] Disallow casting tensor dimensions #93349
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir

Author: Jakub Kuderski (kuhar)

Changes

Tighten the verifier for arith cast ops to disallow changing tensor dimensions, e.g., static to dynamic. After this change:

* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid

This is mostly to simplify the op semantics. See the discussion thread for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.

Full diff: https://github.com/llvm/llvm-project/pull/93349.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
index 4e4c6fd601777..bdf264aec1d5d 100644
--- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
+++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
@@ -83,12 +83,25 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}
+// Checks that the tensor input and output have identical shapes. This is stricter
+// than the verification done in `SameOperandsAndResultShape` that allows for
+// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
+// compatible with static ones).
+def SameInputOutputTensorDims : PredOpTrait<
+ "input and output have the same tensor dimensions",
+ AllMatchSameOperatorPred<["in", "out"],
+ "(::llvm::isa<::mlir::TensorType>($_self.getType()) ?"
+ " ::llvm::cast<::mlir::TensorType>($_self.getType()).getShape() :"
+ " ::llvm::ArrayRef<int64_t>{})">>;
+
// Base class for arithmetic cast operations. Requires a single operand and
-// result. If either is a shaped type, then the other must be of the same shape.
+// result. If either is a shaped type, then the other must be of the same
+// shape. In the case of tensor types, this also includes the corresponding
+// operand/result dimensions being equal.
class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
list<Trait> traits = []> :
Arith_Op<mnemonic, traits # [Pure, SameOperandsAndResultShape,
- DeclareOpInterfaceMethods<CastOpInterface>]>,
+ SameInputOutputTensorDims, DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir
index 1a387c20c4b29..e4f95bb0545a2 100644
--- a/mlir/test/Dialect/Arith/canonicalize.mlir
+++ b/mlir/test/Dialect/Arith/canonicalize.mlir
@@ -2950,14 +2950,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}
-// Just checks that this doesn't crash.
-// CHECK-LABEL: @signedExtendSplatAsDynamicShape
-func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
- %splat = arith.constant dense<5> : tensor<2xi16>
- %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
- return %extsplat : tensor<?xi64>
-}
-
// CHECK-LABEL: @extsi_i0
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
// CHECK: return %[[ZERO]] : i16
diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir
index ada849220bb83..a3cfb6baa2e1d 100644
--- a/mlir/test/Dialect/Arith/invalid.mlir
+++ b/mlir/test/Dialect/Arith/invalid.mlir
@@ -1,13 +1,21 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics
func.func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
- // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
+ // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.index_cast %arg0 : tensor<index> to tensor<2xi64>
return %0 : tensor<2xi64>
}
// -----
+func.func @test_index_cast_shape_dim_error(%arg0 : tensor<2xindex>) -> tensor<?xi64> {
+ // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<?xi64>
+ return %0 : tensor<?xi64>
+}
+
+// -----
+
func.func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
// expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
%0 = arith.index_cast %arg0 : tensor<index> to i64
@@ -655,6 +663,14 @@ func.func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {
// -----
+func.func @extsi_tensor_dim(%arg0 : tensor<4xi32>) {
+ // expected-error@+1 {{'arith.extsi' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
+ return
+}
+
+// -----
+
func.func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
// expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
%0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
@@ -703,6 +719,14 @@ func.func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {
// -----
+func.func @bitcast_tensor_dim(%arg0 : tensor<4xf32>) {
+ // expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+ %0 = arith.bitcast %arg0 : tensor<4xf32> to tensor<?xi32>
+ return
+}
+
+// -----
+
func.func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
// expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
%0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>
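For readers less familiar with ODS predicates, the `SameInputOutputTensorDims` trait above expands, conceptually, into a check like the following. This is a hand-written C++ sketch of the generated verifier logic, not the actual ODS output, and the helper names are made up for illustration:

```c++
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Map a value to its tensor shape, or to an empty list for non-tensors.
// Dynamic dimensions are encoded as ShapedType::kDynamic, so a static
// dimension (e.g. 4) never compares equal to a dynamic one ('?').
static ArrayRef<int64_t> getTensorDims(Value v) {
  if (auto tensorTy = dyn_cast<TensorType>(v.getType()))
    return tensorTy.getShape();
  return {};
}

// The predicate holds iff 'in' and 'out' yield identical dimension lists.
static bool sameInputOutputTensorDims(Value in, Value out) {
  return getTensorDims(in) == getTensorDims(out);
}
```

Because `SameOperandsAndResultShape` only requires shapes to be *compatible*, a pair like `tensor<2xi16>` and `tensor<?xi64>` previously passed verification; under the stricter element-wise equality, the `2` vs. `ShapedType::kDynamic` mismatch is rejected.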
Tighten the verifier for arith cast ops to disallow changing tensor dimensions, e.g., static to dynamic. After this change:

* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid

This is mostly to simplify the op semantics. See the discussion thread for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.
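IR that previously relied on a single cast op to both convert the element type and relax a static dimension now has to express the two steps separately. A minimal migration sketch (not taken from the PR; the function name is hypothetical), using `tensor.cast` for the shape change:

```mlir
func.func @static_to_dynamic(%x: tensor<4xi32>) -> tensor<?xi64> {
  // Element-type conversion alone; the static shape is preserved.
  %ext = arith.extsi %x : tensor<4xi32> to tensor<4xi64>
  // The static-to-dynamic dimension change is now spelled out explicitly.
  %dyn = tensor.cast %ext : tensor<4xi64> to tensor<?xi64>
  return %dyn : tensor<?xi64>
}
```

`tensor.cast` remains the standard way to convert between compatible tensor shapes, so no expressiveness is lost by the stricter arith verifier.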
This will help make things more sane, at least until we stop supporting tensors in