-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Tosa] Fix attr type of out_shape for tosa.transpose_conv2d
#108041
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Longsheng Mou (CoTinker) ChangesThis patch fixes attr type of out_shape, which is i64 dense array attribute with at least 4 elements.
Fixes #107804. Full diff: https://github.com/llvm/llvm-project/pull/108041.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..8ad741b3e65fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -352,7 +352,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
- Tosa_IntArrayAttrUpto4:$out_shape,
+ Tosa_IntArrayAttrAtLeast4:$out_shape,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..99f430cefa2f1e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -151,9 +151,6 @@ def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
//===----------------------------------------------------------------------===//
-class DenseArrayMaxCt<int n> : AttrConstraint<
- CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
- "with at least " # n # " elements">;
def Tosa_Fp32ArrayAttr2 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<2>]>;
def Tosa_Fp32ArrayAttr3 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<3>]>;
@@ -171,6 +168,8 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
+def Tosa_IntArrayAttrAtLeast4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMinCt<4>]>;
+
def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
"arbitrary float attribute"> {
let storageType = [{ ::mlir::FloatAttr }];
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 6774a7c568315d..853fb318c76e71 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -789,6 +789,14 @@ class DenseArrayCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() == " #n>,
"with exactly " # n # " elements">;
+class DenseArrayMaxCt<int n> : AttrConstraint<
+ CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
+ "with at most " # n # " elements">;
+
+class DenseArrayMinCt<int n> : AttrConstraint<
+ CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() >= " # n>,
+ "with at least " # n # " elements">;
+
class DenseArrayStrictlyPositive<DenseArrayAttrBase arrayType> : AttrConstraint<
CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), "
"[&](auto v) { return v > 0; })">,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 418f7687b3cce8..0c38206e69423f 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -526,3 +526,12 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
%0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<6xi16>) -> tensor<?x15xi16>
return
}
+
+// -----
+
+// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
+func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+ // expected-error@+1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with at least 4 elements}}
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
+}
|
@llvm/pr-subscribers-mlir-core Author: Longsheng Mou (CoTinker) ChangesThis patch fixes attr type of out_shape, which is i64 dense array attribute with at least 4 elements.
Fixes #107804. Full diff: https://github.com/llvm/llvm-project/pull/108041.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..8ad741b3e65fe5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -352,7 +352,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
- Tosa_IntArrayAttrUpto4:$out_shape,
+ Tosa_IntArrayAttrAtLeast4:$out_shape,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
index 14fc9c7a6730cc..99f430cefa2f1e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
@@ -151,9 +151,6 @@ def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
//===----------------------------------------------------------------------===//
// Attribute predicates and classes.
//===----------------------------------------------------------------------===//
-class DenseArrayMaxCt<int n> : AttrConstraint<
- CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
- "with at least " # n # " elements">;
def Tosa_Fp32ArrayAttr2 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<2>]>;
def Tosa_Fp32ArrayAttr3 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<3>]>;
@@ -171,6 +168,8 @@ def Tosa_IntArrayAttrUpto2 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<2>
def Tosa_IntArrayAttrUpto4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<4>]>;
def Tosa_IntArrayAttrUpto5 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMaxCt<5>]>;
+def Tosa_IntArrayAttrAtLeast4 : ConfinedAttr<DenseI64ArrayAttr, [DenseArrayMinCt<4>]>;
+
def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
"arbitrary float attribute"> {
let storageType = [{ ::mlir::FloatAttr }];
diff --git a/mlir/include/mlir/IR/CommonAttrConstraints.td b/mlir/include/mlir/IR/CommonAttrConstraints.td
index 6774a7c568315d..853fb318c76e71 100644
--- a/mlir/include/mlir/IR/CommonAttrConstraints.td
+++ b/mlir/include/mlir/IR/CommonAttrConstraints.td
@@ -789,6 +789,14 @@ class DenseArrayCount<int n> : AttrConstraint<
CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() == " #n>,
"with exactly " # n # " elements">;
+class DenseArrayMaxCt<int n> : AttrConstraint<
+ CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
+ "with at most " # n # " elements">;
+
+class DenseArrayMinCt<int n> : AttrConstraint<
+ CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() >= " # n>,
+ "with at least " # n # " elements">;
+
class DenseArrayStrictlyPositive<DenseArrayAttrBase arrayType> : AttrConstraint<
CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), "
"[&](auto v) { return v > 0; })">,
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 418f7687b3cce8..0c38206e69423f 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -526,3 +526,12 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
%0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<6xi16>) -> tensor<?x15xi16>
return
}
+
+// -----
+
+// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
+func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
+ // expected-error@+1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with at least 4 elements}}
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ return %0 : tensor<1x32x32x16xf32>
+}
|
This patch fixes attr type of out_shape, which is i64 dense array attribute with exactly 4 elements. - Fix description of DenseArrayMaxCt - Add DenseArrayMinCt and move it to CommonAttrConstraints.td - Change type of out_shape to Tosa_IntArrayAttr4
75abdbd
to
f43ec6f
Compare
I see, thanks. |
This patch fixes attr type of out_shape, which is i64 dense array attribute with exactly 4 elements.
Fixes #107804.