Commit 5bfe4b9

[mlir][arith] Disallow casting tensor dimensions (#93349)
Tighten the verifier for arith cast ops to disallow changing tensor dimensions, e.g., static to dynamic. After this change:

* `arith.cast_op %x : tensor<4xi32> to tensor<4xf32>` remains valid
* `arith.cast_op %x : tensor<4xi32> to tensor<?xf32>` becomes invalid
* `arith.cast_op %x : tensor<?xi32> to tensor<4xf32>` becomes invalid

This is mostly to simplify the op semantics. See the discussion thread for more context: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357/63.
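A minimal sketch of the new behavior using `arith.extsi`, one of the affected cast ops (the function name is illustrative):

```mlir
// Element-type conversion with identical tensor dimensions still verifies.
func.func @extsi_static_dims(%arg0 : tensor<4xi32>) -> tensor<4xi64> {
  %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<4xi64>
  return %0 : tensor<4xi64>
}

// A cast that also turns the static dimension into a dynamic one is now rejected:
//   'arith.extsi' op failed to verify that input and output have the same tensor dimensions
// %1 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
```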
1 parent bd5cd4b commit 5bfe4b9

3 files changed: +57, -12 lines

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 16 additions & 3 deletions
```diff
@@ -83,12 +83,25 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
                           attr-dict `:` type($result) }];
 }
 
+// Checks that tensor input and outputs have identical shapes. This is stricter
+// than the verification done in `SameOperandsAndResultShape` that allows for
+// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
+// compatible with static ones).
+def SameInputOutputTensorDims : PredOpTrait<
+  "input and output have the same tensor dimensions",
+   AllMatchSameOperatorPred<["in", "out"],
+                            "(::llvm::isa<::mlir::TensorType>($_self.getType()) ?"
+                            " ::llvm::cast<::mlir::TensorType>($_self.getType()).getShape() :"
+                            " ::llvm::ArrayRef<int64_t>{})">>;
+
 // Base class for arithmetic cast operations. Requires a single operand and
-// result. If either is a shaped type, then the other must be of the same shape.
+// result. If either is a shaped type, then the other must be of the same
+// shape. In the case of tensor types, this also includes the corresponding
+// operand/result dimensions being equal.
 class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
                    list<Trait> traits = []> :
     Arith_Op<mnemonic, traits # [Pure, SameOperandsAndResultShape,
-        DeclareOpInterfaceMethods<CastOpInterface>]>,
+        SameInputOutputTensorDims, DeclareOpInterfaceMethods<CastOpInterface>]>,
     Arguments<(ins From:$in)>,
     Results<(outs To:$out)> {
   let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
@@ -1231,7 +1244,7 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {
 
 def Arith_TruncFOp :
   Arith_Op<"truncf",
-    [Pure, SameOperandsAndResultShape,
+    [Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
      DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
      DeclareOpInterfaceMethods<CastOpInterface>]>,
   Arguments<(ins FloatLike:$in,
```
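The predicate above compares `getShape()` of the operand and result when they are tensors and evaluates to an empty `ArrayRef<int64_t>` otherwise, so scalars and vectors are unaffected and remain checked only by `SameOperandsAndResultShape`. A minimal sketch of IR that still verifies with the new trait in place (the function name is illustrative):

```mlir
// Scalars and vectors carry no tensor dimensions, so only the existing
// SameOperandsAndResultShape verification applies to them.
func.func @scalar_and_vector_ext(%a : i32, %v : vector<4xi32>) -> (i64, vector<4xi64>) {
  %0 = arith.extsi %a : i32 to i64
  %1 = arith.extsi %v : vector<4xi32> to vector<4xi64>
  return %0, %1 : i64, vector<4xi64>
}
```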

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 0 additions & 8 deletions
```diff
@@ -2950,14 +2950,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
   return %ext : tensor<i16>
 }
 
-// Just checks that this doesn't crash.
-// CHECK-LABEL: @signedExtendSplatAsDynamicShape
-func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
-  %splat = arith.constant dense<5> : tensor<2xi16>
-  %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
-  return %extsplat : tensor<?xi64>
-}
-
 // CHECK-LABEL: @extsi_i0
 // CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
 // CHECK: return %[[ZERO]] : i16
```
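The deleted test relied on `arith.extsi` producing a dynamically shaped tensor from a statically shaped one, which the tightened verifier now rejects, so the case is dropped rather than updated. A statically shaped variant of the same pattern remains valid (a sketch, not taken from the test file):

```mlir
// The splat extension with matching static dimensions still verifies.
func.func @signedExtendSplatStaticShape() -> tensor<2xi64> {
  %splat = arith.constant dense<5> : tensor<2xi16>
  %extsplat = arith.extsi %splat : tensor<2xi16> to tensor<2xi64>
  return %extsplat : tensor<2xi64>
}
```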

mlir/test/Dialect/Arith/invalid.mlir

Lines changed: 41 additions & 1 deletion
```diff
@@ -1,13 +1,21 @@
 // RUN: mlir-opt -split-input-file %s -verify-diagnostics
 
 func.func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
-  // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
+  // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
   %0 = arith.index_cast %arg0 : tensor<index> to tensor<2xi64>
   return %0 : tensor<2xi64>
 }
 
 // -----
 
+func.func @test_index_cast_shape_dim_error(%arg0 : tensor<2xindex>) -> tensor<?xi64> {
+  // expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<?xi64>
+  return %0 : tensor<?xi64>
+}
+
+// -----
+
 func.func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
   // expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
   %0 = arith.index_cast %arg0 : tensor<index> to i64
@@ -655,6 +663,14 @@ func.func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {
 
 // -----
 
+func.func @extsi_tensor_dim(%arg0 : tensor<4xi32>) {
+  // expected-error@+1 {{'arith.extsi' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
+  return
+}
+
+// -----
+
 func.func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
   // expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
   %0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
@@ -703,6 +719,22 @@ func.func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {
 
 // -----
 
+func.func @bitcast_tensor_dim(%arg0 : tensor<4xf32>) {
+  // expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.bitcast %arg0 : tensor<4xf32> to tensor<?xi32>
+  return
+}
+
+// -----
+
+func.func @bitcast_tensor_dim(%arg0 : tensor<?xf32>) {
+  // expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.bitcast %arg0 : tensor<?xf32> to tensor<4xi32>
+  return
+}
+
+// -----
+
 func.func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
   // expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
   %0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>
@@ -719,6 +751,14 @@ func.func @truncf_fl_to_scalable(%arg0 : vector<4xf64>) {
 
 // -----
 
+func.func @truncf_tensor_dim(%arg0 : tensor<4xf64>) {
+  // expected-error@+1 {{'arith.truncf' op failed to verify that input and output have the same tensor dimensions}}
+  %0 = arith.truncf %arg0 : tensor<4xf64> to tensor<?xf32>
+  return
+}
+
+// -----
+
 func.func @extui_fl_to_scalable(%arg0 : vector<4xi32>) {
   // expected-error@+1 {{'arith.extui' op requires the same shape for all operands and results}}
   %0 = arith.extui %arg0 : vector<4xi32> to vector<[4]xi64>
```
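For IR that previously folded a static-to-dynamic (or dynamic-to-static) dimension change into an arith cast, the two concerns can be expressed separately: the arith cast keeps the dimensions, and a `tensor.cast` handles the shape change. A sketch of that pattern under the new rules (not part of this commit; the function name is illustrative):

```mlir
func.func @extsi_then_relax_dims(%arg0 : tensor<4xi32>) -> tensor<?xi64> {
  // The element-type conversion leaves the tensor dimensions untouched.
  %0 = arith.extsi %arg0 : tensor<4xi32> to tensor<4xi64>
  // The static-to-dynamic dimension change is expressed with tensor.cast.
  %1 = tensor.cast %0 : tensor<4xi64> to tensor<?xi64>
  return %1 : tensor<?xi64>
}
```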
