
Commit 2ec9d50

Merge pull request #33 from Xilinx/christopher.FXML-1991_update_add_and_sub_folding
[FXML-1991] Update Add and Sub Constant Folding
2 parents 1265071 + eb21102 commit 2ec9d50


2 files changed: +38 −4 lines changed


mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 6 additions & 4 deletions

@@ -474,8 +474,6 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
   auto resultTy = getType().dyn_cast<RankedTensorType>();
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
-  if (lhsTy != rhsTy)
-    return {};
 
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
@@ -504,6 +502,9 @@ OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
+  if (lhsTy != rhsTy)
+    return {};
+
   return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                             lhsTy);
 }
@@ -635,8 +636,6 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
   auto resultTy = getType().dyn_cast<RankedTensorType>();
   if (!lhsTy || !rhsTy || !resultTy)
     return {};
-  if (lhsTy != rhsTy)
-    return {};
 
   auto resultETy = resultTy.getElementType();
   auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
@@ -655,6 +654,9 @@ OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
   if (!lhsAttr || !rhsAttr)
     return {};
 
+  if (lhsTy != rhsTy)
+    return {};
+
   return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
                                                               lhsTy);
 }
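Taken together, the two pairs of hunks move the lhsTy != rhsTy early return below the code that sits between them, so that check now only gates the element-wise constant fold rather than the whole folder. A minimal sketch of the resulting ordering in AddOp::fold, assembled from the hunks above rather than copied from the full file (the elided middle is the unchanged code between the hunks), looks like this:

  // Sketch of AddOp::fold after this change (assembled from the hunks above,
  // not a verbatim copy of the file).
  auto resultTy = getType().dyn_cast<RankedTensorType>();
  if (!lhsTy || !rhsTy || !resultTy)
    return {};

  auto resultETy = resultTy.getElementType();
  auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
  // ... unchanged code between the hunks: the identity folds (adding or
  // subtracting a zero splat returns the other operand) that the new tests
  // below exercise ...

  if (!lhsAttr || !rhsAttr)
    return {};

  // The type-equality check now only guards the element-wise constant fold,
  // so it no longer blocks the identity folds above when the operand shapes
  // differ but are broadcast-compatible.
  if (lhsTy != rhsTy)
    return {};

  return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
                                                            lhsTy);

SubOp::fold receives the same reordering, with std::minus in place of std::plus.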

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 32 additions & 0 deletions

@@ -164,6 +164,28 @@ func.func @fold_add_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_add_zero_splat_different_shape_f32
+func.func @fold_add_zero_splat_different_shape_f32(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %add = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32>
+  // CHECK: return %arg0
+  return %add : tensor<1x10xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @fold_add_zero_broadcast_arg_f32
+func.func @fold_add_zero_broadcast_arg_f32(%arg0: tensor<1x10xf32>) -> tensor<4x10xf32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<4x10xf32>
+  %add = "tosa.add"(%arg0, %zero) : (tensor<1x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
+  // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<4x10xf32>
+  // CHECK: %[[ADD:.+]] = "tosa.add"(%arg0, %[[ZERO]]) : (tensor<1x10xf32>, tensor<4x10xf32>) -> tensor<4x10xf32>
+  // CHECK: return %[[ADD]] : tensor<4x10xf32>
+  return %add : tensor<4x10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_div_zero_lhs_i32
 func.func @fold_div_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
@@ -350,6 +372,16 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
 
 // -----
 
+// CHECK-LABEL: @fold_sub_zero_splat_different_shape_f32
+func.func @fold_sub_zero_splat_different_shape_f32(%arg0: tensor<1x10xf32>) -> tensor<1x10xf32> {
+  %zero = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
+  %sub = "tosa.sub"(%arg0, %zero) : (tensor<1x10xf32>, tensor<1x1xf32>) -> tensor<1x10xf32>
+  // CHECK: return %arg0
+  return %sub : tensor<1x10xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_greater_splat_f32
 func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
   %0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
