Skip to content

[mlir][arith] Disallow casting tensor dimensions #93349

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,25 @@ class Arith_FloatBinaryOp<string mnemonic, list<Trait> traits = []> :
attr-dict `:` type($result) }];
}

// Checks that tensor inputs and outputs have identical shapes. This is stricter
// than the verification done in `SameOperandsAndResultShape` that allows for
// tensor dimensions to be 'compatible' (e.g., dynamic dimensions being
// compatible with static ones).
def SameInputOutputTensorDims : PredOpTrait<
"input and output have the same tensor dimensions",
AllMatchSameOperatorPred<["in", "out"],
"(::llvm::isa<::mlir::TensorType>($_self.getType()) ?"
" ::llvm::cast<::mlir::TensorType>($_self.getType()).getShape() :"
" ::llvm::ArrayRef<int64_t>{})">>;

// Base class for arithmetic cast operations. Requires a single operand and
// result. If either is a shaped type, then the other must be of the same shape.
// result. If either is a shaped type, then the other must be of the same
// shape. In the case of tensor types, this also includes the corresponding
// operand/result dimensions being equal.
class Arith_CastOp<string mnemonic, TypeConstraint From, TypeConstraint To,
list<Trait> traits = []> :
Arith_Op<mnemonic, traits # [Pure, SameOperandsAndResultShape,
DeclareOpInterfaceMethods<CastOpInterface>]>,
SameInputOutputTensorDims, DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins From:$in)>,
Results<(outs To:$out)> {
let assemblyFormat = "$in attr-dict `:` type($in) `to` type($out)";
Expand Down Expand Up @@ -1231,7 +1244,7 @@ def Arith_TruncIOp : Arith_IToICastOp<"trunci"> {

def Arith_TruncFOp :
Arith_Op<"truncf",
[Pure, SameOperandsAndResultShape,
[Pure, SameOperandsAndResultShape, SameInputOutputTensorDims,
DeclareOpInterfaceMethods<ArithRoundingModeInterface>,
DeclareOpInterfaceMethods<CastOpInterface>]>,
Arguments<(ins FloatLike:$in,
Expand Down
8 changes: 0 additions & 8 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2950,14 +2950,6 @@ func.func @unsignedExtendConstantResource() -> tensor<i16> {
return %ext : tensor<i16>
}

// Just checks that this doesn't crash.
// CHECK-LABEL: @signedExtendSplatAsDynamicShape
func.func @signedExtendSplatAsDynamicShape() -> tensor<?xi64> {
%splat = arith.constant dense<5> : tensor<2xi16>
%extsplat = arith.extsi %splat : tensor<2xi16> to tensor<?xi64>
return %extsplat : tensor<?xi64>
}

// CHECK-LABEL: @extsi_i0
// CHECK: %[[ZERO:.*]] = arith.constant 0 : i16
// CHECK: return %[[ZERO]] : i16
Expand Down
42 changes: 41 additions & 1 deletion mlir/test/Dialect/Arith/invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
// RUN: mlir-opt -split-input-file %s -verify-diagnostics

func.func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
// expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
// expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.index_cast %arg0 : tensor<index> to tensor<2xi64>
return %0 : tensor<2xi64>
}

// -----

func.func @test_index_cast_shape_dim_error(%arg0 : tensor<2xindex>) -> tensor<?xi64> {
// expected-error @+1 {{'arith.index_cast' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.index_cast %arg0 : tensor<2xindex> to tensor<?xi64>
return %0 : tensor<?xi64>
}

// -----

func.func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
// expected-error @+1 {{'arith.index_cast' op requires the same shape for all operands and results}}
%0 = arith.index_cast %arg0 : tensor<index> to i64
Expand Down Expand Up @@ -655,6 +663,14 @@ func.func @extsi_scalable_to_fl(%arg0 : vector<[4]xi32>) {

// -----

func.func @extsi_tensor_dim(%arg0 : tensor<4xi32>) {
// expected-error@+1 {{'arith.extsi' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.extsi %arg0 : tensor<4xi32> to tensor<?xi64>
return
}

// -----

func.func @extf_scalable_to_fl(%arg0 : vector<[4]xf32>) {
// expected-error@+1 {{'arith.extf' op requires the same shape for all operands and results}}
%0 = arith.extf %arg0 : vector<[4]xf32> to vector<4xf64>
Expand Down Expand Up @@ -703,6 +719,22 @@ func.func @bitcast_scalable_to_fl(%arg0 : vector<[4]xf32>) {

// -----

func.func @bitcast_tensor_dim(%arg0 : tensor<4xf32>) {
// expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.bitcast %arg0 : tensor<4xf32> to tensor<?xi32>
return
}

// -----

func.func @bitcast_tensor_dim(%arg0 : tensor<?xf32>) {
// expected-error@+1 {{'arith.bitcast' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.bitcast %arg0 : tensor<?xf32> to tensor<4xi32>
return
}

// -----

func.func @trunci_fl_to_scalable(%arg0 : vector<4xi32>) {
// expected-error@+1 {{'arith.trunci' op requires the same shape for all operands and results}}
%0 = arith.trunci %arg0 : vector<4xi32> to vector<[4]xi8>
Expand All @@ -719,6 +751,14 @@ func.func @truncf_fl_to_scalable(%arg0 : vector<4xf64>) {

// -----

func.func @truncf_tensor_dim(%arg0 : tensor<4xf64>) {
// expected-error@+1 {{'arith.truncf' op failed to verify that input and output have the same tensor dimensions}}
%0 = arith.truncf %arg0 : tensor<4xf64> to tensor<?xf32>
return
}

// -----

func.func @extui_fl_to_scalable(%arg0 : vector<4xi32>) {
// expected-error@+1 {{'arith.extui' op requires the same shape for all operands and results}}
%0 = arith.extui %arg0 : vector<4xi32> to vector<[4]xi64>
Expand Down
Loading