@@ -164,6 +164,28 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
164
164
165
165
// -----
166
166
167
+ // CHECK-LABEL: @fold_add_zero_splat_different_shape_f32
168
+ func.func @fold_add_zero_splat_different_shape_f32 (%arg0: tensor <1 x10 xf32 >) -> tensor <1 x10 xf32 > {
169
+ %zero = " tosa.const" () {value = dense <0.0 > : tensor <1 x1 xf32 >} : () -> tensor <1 x1 xf32 >
170
+ %add = " tosa.add" (%arg0 , %zero ) : (tensor <1 x10 xf32 >, tensor <1 x1 xf32 >) -> tensor <1 x10 xf32 >
171
+ // CHECK: return %arg0
172
+ return %add : tensor <1 x10 xf32 >
173
+ }
174
+
175
+ // -----
176
+
177
+ // CHECK-LABEL: @fold_add_zero_broadcast_arg_f32
178
+ func.func @fold_add_zero_broadcast_arg_f32 (%arg0: tensor <1 x10 xf32 >) -> tensor <4 x10 xf32 > {
179
+ %zero = " tosa.const" () {value = dense <0.0 > : tensor <1 x1 xf32 >} : () -> tensor <4 x10 xf32 >
180
+ %add = " tosa.add" (%arg0 , %zero ) : (tensor <1 x10 xf32 >, tensor <4 x10 xf32 >) -> tensor <4 x10 xf32 >
181
+ // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<4x10xf32>
182
+ // CHECK: %[[ADD:.+]] = "tosa.add"(%arg0, %[[ZERO]]) : (tensor<1x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
183
+ // CHECK: return %[[ADD]] : tensor<4x10xf32>
184
+ return %add : tensor <4 x10 xf32 >
185
+ }
186
+
187
+ // -----
188
+
167
189
// CHECK-LABEL: @fold_div_zero_lhs_i32
168
190
func.func @fold_div_zero_lhs_i32 (%arg0: tensor <i32 >) -> tensor <i32 > {
169
191
%zero = " tosa.const" () {value = dense <0 > : tensor <i32 >} : () -> tensor <i32 >
@@ -350,6 +372,16 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
350
372
351
373
// -----
352
374
375
+ // CHECK-LABEL: @fold_sub_zero_splat_different_shape_f32
376
+ func.func @fold_sub_zero_splat_different_shape_f32 (%arg0: tensor <1 x10 xf32 >) -> tensor <1 x10 xf32 > {
377
+ %zero = " tosa.const" () {value = dense <0.0 > : tensor <1 x1 xf32 >} : () -> tensor <1 x1 xf32 >
378
+ %sub = " tosa.sub" (%arg0 , %zero ) : (tensor <1 x10 xf32 >, tensor <1 x1 xf32 >) -> tensor <1 x10 xf32 >
379
+ // CHECK: return %arg0
380
+ return %sub : tensor <1 x10 xf32 >
381
+ }
382
+
383
+ // -----
384
+
353
385
// CHECK-LABEL: @fold_greater_splat_f32
354
386
func.func @fold_greater_splat_f32 () -> (tensor <10 xi1 >, tensor <10 xi1 >) {
355
387
%0 = " tosa.const" () {value = dense <4.0 > : tensor <10 xf32 >} : () -> tensor <10 xf32 >
0 commit comments