Skip to content

Commit 1a431bc

Browse files
authored
[mlir][Tosa] Fix attr type of out_shape for tosa.transpose_conv2d (#108041)
This patch fixes attr type of out_shape, which is i64 dense array attribute with exactly 4 elements. - Fix description of DenseArrayMaxCt - Add DenseArrayMinCt and move it to CommonAttrConstraints.td - Change type of out_shape to Tosa_IntArrayAttr4 Fixes #107804.
1 parent 44d1221 commit 1a431bc

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
347347
Tosa_Tensor1D:$bias,
348348
Tosa_IntArrayAttr4:$out_pad,
349349
Tosa_IntArrayAttr2:$stride,
350-
Tosa_IntArrayAttrUpto4:$out_shape,
350+
Tosa_IntArrayAttr4:$out_shape,
351351
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
352352
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
353353
);

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,6 @@ def Tosa_Int8Like : Tosa_TypeLike<[Tosa_Int8], "signless-integer-8-bit-like">;
176176
//===----------------------------------------------------------------------===//
177177
// Attribute predicates and classes.
178178
//===----------------------------------------------------------------------===//
179-
class DenseArrayMaxCt<int n> : AttrConstraint<
180-
CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
181-
"with at least " # n # " elements">;
182179

183180
def Tosa_Fp32ArrayAttr2 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<2>]>;
184181
def Tosa_Fp32ArrayAttr3 : ConfinedAttr<DenseF32ArrayAttr, [DenseArrayCount<3>]>;

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,14 @@ class DenseArrayCount<int n> : AttrConstraint<
789789
CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() == " #n>,
790790
"with exactly " # n # " elements">;
791791

792+
class DenseArrayMaxCt<int n> : AttrConstraint<
793+
CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() <= " # n>,
794+
"with at most " # n # " elements">;
795+
796+
class DenseArrayMinCt<int n> : AttrConstraint<
797+
CPred<"::llvm::cast<::mlir::DenseArrayAttr>($_self).size() >= " # n>,
798+
"with at least " # n # " elements">;
799+
792800
class DenseArrayStrictlyPositive<DenseArrayAttrBase arrayType> : AttrConstraint<
793801
CPred<"::llvm::all_of(::llvm::cast<" # arrayType #">($_self).asArrayRef(), "
794802
"[&](auto v) { return v > 0; })">,

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,3 +578,12 @@ func.func @test_table_io_shape_mismatch(%arg0: tensor<?x16xi16>, %arg1: tensor<6
578578
%0 = tosa.table %arg0, %arg1 : (tensor<?x16xi16>, tensor<6xi16>) -> tensor<?x15xi16>
579579
return
580580
}
581+
582+
// -----
583+
584+
// CHECK-LABEL: test_transpose_conv2d_invalid_outshape
585+
func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
586+
// expected-error@+1 {{'tosa.transpose_conv2d' op attribute 'out_shape' failed to satisfy constraint: i64 dense array attribute with exactly 4 elements}}
587+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
588+
return %0 : tensor<1x32x32x16xf32>
589+
}

0 commit comments

Comments
 (0)