Skip to content

Commit 22d05ac

Browse files
authored
[Tosa] Add local_bound attribute (#73001)
This adds an optional bool attribute, local_bound, with default false, to following ops per TOSA spec 0.90: CONV2D CONV3D DEPTHWISE_CONV2D FFT2D RFFT2D TRANSPOSE_CONV2D also added tests in ops.mlir to validate this attribute is optional Signed-off-by: Tai Ly <[email protected]>
1 parent ea47887 commit 22d05ac

File tree

2 files changed

+55
-7
lines changed

2 files changed

+55
-7
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
107107
Tosa_IntArrayAttr4:$pad,
108108
Tosa_IntArrayAttr2:$stride,
109109
Tosa_IntArrayAttr2:$dilation,
110-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
110+
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
111+
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
111112
);
112113

113114
let results = (outs
@@ -136,7 +137,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
136137
Tosa_IntArrayAttr6:$pad,
137138
Tosa_IntArrayAttr3:$stride,
138139
Tosa_IntArrayAttr3:$dilation,
139-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
140+
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
141+
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
140142
);
141143

142144
let results = (outs
@@ -166,7 +168,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
166168
Tosa_IntArrayAttr4:$pad,
167169
Tosa_IntArrayAttr2:$stride,
168170
Tosa_IntArrayAttr2:$dilation,
169-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
171+
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
172+
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
170173
);
171174

172175
let results = (outs
@@ -201,7 +204,8 @@ def Tosa_FFT2dOp : Tosa_InferShapedTypeOp<"fft2d"> {
201204
Tosa_Tensor3D:$input_real,
202205
Tosa_Tensor3D:$input_imag,
203206

204-
BoolAttr:$inverse
207+
BoolAttr:$inverse,
208+
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
205209
);
206210

207211
let results = (outs
@@ -315,7 +319,8 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
315319
}];
316320

317321
let arguments = (ins
318-
Tosa_Tensor3D:$input
322+
Tosa_Tensor3D:$input,
323+
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
319324
);
320325

321326
let results = (outs
@@ -347,7 +352,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
347352
Tosa_IntArrayAttr4:$out_pad,
348353
Tosa_IntArrayAttr2:$stride,
349354
Tosa_IntArrayAttrUpto4:$out_shape,
350-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
355+
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
356+
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
351357
);
352358

353359
let results = (outs

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>
5454
// -----
5555
// CHECK-LABEL: conv2d
5656
func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
57-
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
57+
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
5858
return %0 : tensor<1x4x4x8xf32>
5959
}
6060

@@ -68,20 +68,48 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
6868
return %3 : tensor<1x1x1x3xi8>
6969
}
7070

71+
// -----
72+
// CHECK-LABEL: conv3d
73+
func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
74+
%0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
75+
return %0 : tensor<1x4x8x21x34xf32>
76+
}
77+
78+
// -----
79+
// CHECK-LABEL: conv3d_with_local_bound
80+
func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> {
81+
%0 = tosa.conv3d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1, 1>, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>) -> tensor<1x4x8x21x34xf32>
82+
return %0 : tensor<1x4x8x21x34xf32>
83+
}
84+
7185
// -----
7286
// CHECK-LABEL: depthwise_conv2d
7387
func.func @test_depthwise_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
7488
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
7589
return %0 : tensor<1x4x4x8xf32>
7690
}
7791

92+
// -----
93+
// CHECK-LABEL: depthwise_conv2d_with_local_bound
94+
func.func @test_depthwise_conv2d_with_local_bound(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<1x1x4x2xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {
95+
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, local_bound = true} : (tensor<1x4x4x4xf32>, tensor<1x1x4x2xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>
96+
return %0 : tensor<1x4x4x8xf32>
97+
}
98+
7899
// -----
79100
// CHECK-LABEL: fft2d
80101
func.func @test_fft2d(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
81102
%0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
82103
return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
83104
}
84105

106+
// -----
107+
// CHECK-LABEL: fft2d_with_local_bound
108+
func.func @test_fft2d_with_local_bound(%arg0: tensor<1x4x8xf32>, %arg1: tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>) {
109+
%0, %1 = tosa.fft2d %arg0, %arg1 {inverse = false, local_bound = true} : (tensor<1x4x8xf32>, tensor<1x4x8xf32>) -> (tensor<1x4x8xf32>, tensor<1x4x8xf32>)
110+
return %0, %1 : tensor<1x4x8xf32>, tensor<1x4x8xf32>
111+
}
112+
85113
// -----
86114
// CHECK-LABEL: fully_connected
87115
func.func @test_fully_connected(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> {
@@ -124,13 +152,27 @@ func.func @test_rfft2d(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tenso
124152
return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
125153
}
126154

155+
// -----
156+
// CHECK-LABEL: rfft2d_with_local_bound
157+
func.func @test_rfft2d_with_local_bound(%arg0: tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>) {
158+
%0, %1 = tosa.rfft2d %arg0 {local_bound = true} : (tensor<13x8x16xf32>) -> (tensor<13x8x9xf32>, tensor<13x8x9xf32>)
159+
return %0, %1 : tensor<13x8x9xf32>, tensor<13x8x9xf32>
160+
}
161+
127162
// -----
128163
// CHECK-LABEL: transpose_conv2d
129164
func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
130165
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
131166
return %0 : tensor<1x32x32x16xf32>
132167
}
133168

169+
// -----
170+
// CHECK-LABEL: transpose_conv2d_with_local_bound
171+
func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>, %arg2: tensor<16xf32>) -> tensor<1x32x32x16xf32> {
172+
%0 = tosa.transpose_conv2d %arg0, %arg1, %arg2 {out_pad = array<i64: 0, 0, 0, 0>, out_shape = array<i64: 1, 32, 32, 16>, stride = array<i64: 1, 1>, local_bound = false} : (tensor<1x32x32x8xf32>, tensor<16x1x1x8xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
173+
return %0 : tensor<1x32x32x16xf32>
174+
}
175+
134176
// -----
135177
// CHECK-LABEL: clamp
136178
func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)