Skip to content

Commit fc8ecac

Browse files
authored
Merge pull request #105 from Xilinx/tiagot.cherry-pick-fix-avgpool2d
Cherry-picked upstream commit to fix AvgPool2d types
2 parents 3a5f724 + b199e1d commit fc8ecac

File tree

2 files changed

+24
-10
lines changed

2 files changed

+24
-10
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -248,20 +248,20 @@ LogicalResult tosa::AvgPool2dOp::verify() {
248248
if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
249249
return emitOpError("accumulator type for integer tensor is not i32");
250250

251-
if ((inputETy.isBF16() || inputETy.isF16()) &&
252-
!(accType.isF16() || accType.isF32()))
253-
return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
251+
if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
252+
return emitOpError("accumulator type for f16 tensor is not f16/f32");
253+
254+
if (inputETy.isBF16() && !accType.isF32())
255+
return emitOpError("accumulator type for bf16 tensor is not f32");
254256

255257
if (inputETy.isF32() && !accType.isF32())
256258
return emitOpError("accumulator type for f32 tensor is not f32");
257259

258-
if (inputETy.isF32() && resultETy.isF32())
259-
return success();
260-
if (inputETy.isBF16() && resultETy.isBF16())
261-
return success();
262-
if (inputETy.isInteger(8) && resultETy.isInteger(8))
263-
return success();
264-
if (inputETy.isInteger(16) && resultETy.isInteger(16))
260+
if ((inputETy.isF32() && resultETy.isF32()) ||
261+
(inputETy.isF16() && resultETy.isF16()) ||
262+
(inputETy.isBF16() && resultETy.isBF16()) ||
263+
(inputETy.isInteger(8) && resultETy.isInteger(8)) ||
264+
(inputETy.isInteger(16) && resultETy.isInteger(16)))
265265
return success();
266266

267267
return emitOpError("input/output element types are incompatible.");

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,20 @@ func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32
1616
return %0 : tensor<1x7x7x9xf32>
1717
}
1818

19+
// -----
20+
// CHECK-LABEL: avg_pool2d_f16
21+
func.func @test_avg_pool2d_f16(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
22+
%0 = tosa.avg_pool2d %arg0 {acc_type = f16, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
23+
return %0 : tensor<1x7x7x9xf16>
24+
}
25+
26+
// -----
27+
// CHECK-LABEL: avg_pool2d_f16_accumf32
28+
func.func @test_avg_pool2d_f16_accumf32(%arg0: tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16> {
29+
%0 = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf16>) -> tensor<1x7x7x9xf16>
30+
return %0 : tensor<1x7x7x9xf16>
31+
}
32+
1933
// -----
2034
// CHECK-LABEL: avg_pool2d_i8
2135
func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {

0 commit comments

Comments
 (0)