Skip to content

Commit 887687a

Browse files
authored
Merge pull request #13 from Xilinx/matthias.tosa_bitint
TOSA: Test 2 bit integer matmul/cast and 3 bit unsigned cast
2 parents 0955a62 + a0c9b58 commit 887687a

File tree

5 files changed

+91
-5
lines changed

5 files changed

+91
-5
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ class Tosa_QuantizedType<string n, list<int> params, bit signed>
3838
// Used to express accumulator results or compare results.
3939
//===----------------------------------------------------------------------===//
4040

41-
def Tosa_UInt8 : UI<8>;
42-
4341
def Tosa_Int8 : I<8>;
4442
def Tosa_Int16 : I<16>;
4543
def Tosa_Int32 : I<32>;
@@ -54,9 +52,11 @@ def Tosa_SignedInt : AnyTypeOf<[Tosa_Int8,
5452

5553
def Tosa_Bool : I<1>;
5654

57-
// No unsigned unquantized int types.
5855
def Tosa_Int : AnyTypeOf<[Tosa_Bool,
59-
Tosa_UInt8,
56+
AnyUnsignedInteger,
57+
AnySignlessInteger,
58+
// TODO: For backwards compatibility, keep Tosa_SignedInt, which is actually
59+
// a set of signless types.
6060
Tosa_SignedInt]>;
6161

6262
def Tosa_Int32Or64 : AnyTypeOf<[Tosa_Int32,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
469469
args.front(), zero);
470470
}
471471

472-
if (arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
472+
if (dstTy.isSignlessInteger() &&
473+
arith::FPToSIOp::areCastCompatible(srcTy, dstTy)) {
473474
auto intMin = rewriter.create<arith::ConstantOp>(
474475
loc, rewriter.getF32FloatAttr(
475476
APInt::getSignedMinValue(dstTy.getIntOrFloatBitWidth())
@@ -487,6 +488,30 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args,
487488
return rewriter.create<arith::FPToSIOp>(loc, dstTy, clamped);
488489
}
489490

491+
if (dstTy.isUnsignedInteger() &&
492+
arith::FPToUIOp::areCastCompatible(srcTy, dstTy)) {
493+
auto intMin = rewriter.create<arith::ConstantOp>(
494+
loc, rewriter.getF32FloatAttr(
495+
APInt::getMinValue(dstTy.getIntOrFloatBitWidth())
496+
.getZExtValue()));
497+
498+
auto intMax = rewriter.create<arith::ConstantOp>(
499+
loc, rewriter.getF32FloatAttr(
500+
APInt::getMaxValue(dstTy.getIntOrFloatBitWidth())
501+
.getZExtValue()));
502+
503+
auto rounded = rewriter.create<math::RoundEvenOp>(loc, args[0]);
504+
505+
auto clamped = clampFloatHelper(loc, rounded, intMin, intMax, rewriter);
506+
507+
auto cast = rewriter.create<arith::FPToUIOp>(
508+
loc, rewriter.getIntegerType(dstTy.getIntOrFloatBitWidth()), clamped);
509+
// arith is signless, so temporarily cast back to being unsigned.
510+
return rewriter
511+
.create<UnrealizedConversionCastOp>(loc, dstTy, cast->getResult(0))
512+
.getResult(0);
513+
}
514+
490515
// Casting to boolean, integers need to only be checked as not-equal to
491516
// zero.
492517
if (srcTy.isa<IntegerType>() && dstTy.isInteger(1)) {
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s -verify-diagnostics -o -| FileCheck %s
2+
3+
func.func @test_cast(%arg0: tensor<1xf32>) -> tensor<1xf32> {
4+
// CHECK: linalg.generic
5+
// CHECK: arith.constant -2.000000e+00
6+
// CHECK: arith.constant 1.000000e+00
7+
// CHECK: math.roundeven
8+
// CHECK: arith.minf
9+
// CHECK: arith.maxf
10+
// CHECK: arith.fptosi
11+
%1 = "tosa.cast"(%arg0) : (tensor<1xf32>) -> tensor<1xi2>
12+
13+
// CHECK: linalg.generic
14+
// CHECK: arith.sitofp
15+
%2 = "tosa.cast"(%1) : (tensor<1xi2>) -> tensor<1xf32>
16+
17+
return %2 : tensor<1xf32>
18+
}
19+
20+
// -----
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s -verify-diagnostics -o -| FileCheck %s
2+
3+
// CHECK-LABEL: @matmul
4+
func.func @matmul(%arg0: tensor<1x5x3xi2>, %arg1: tensor<1x3x6xi2>) -> (tensor<1x5x6xi2>) {
5+
// CHECK: [[C0:%.+]] = arith.constant 0 : i2
6+
// CHECK: [[INIT:%.+]] = tensor.empty()
7+
// CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : i2) outs([[INIT]] : tensor<1x5x6xi2>) -> tensor<1x5x6xi2>
8+
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xi2>, tensor<1x3x6xi2>) outs([[FILLED]] : tensor<1x5x6xi2>) -> tensor<1x5x6xi2>
9+
%0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xi2>, tensor<1x3x6xi2>) -> (tensor<1x5x6xi2>)
10+
return %0 : tensor<1x5x6xi2>
11+
}
12+
13+
// -----
14+
15+
// CHECK-LABEL: @matmul
16+
func.func @matmul(%arg0: tensor<1x5x3xi2>, %arg1: tensor<1x3x6xi2>) -> (tensor<1x5x6xi4>) {
17+
// CHECK: [[C0:%.+]] = arith.constant 0 : i4
18+
// CHECK: [[INIT:%.+]] = tensor.empty()
19+
// CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : i4) outs([[INIT]] : tensor<1x5x6xi4>) -> tensor<1x5x6xi4>
20+
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xi2>, tensor<1x3x6xi2>) outs([[FILLED]] : tensor<1x5x6xi4>) -> tensor<1x5x6xi4>
21+
%0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xi2>, tensor<1x3x6xi2>) -> (tensor<1x5x6xi4>)
22+
return %0 : tensor<1x5x6xi4>
23+
}
24+
25+
// -----
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" %s | FileCheck %s
2+
3+
func.func @test_cast(%arg0: tensor<1xf32>) -> tensor<1xui3> {
4+
// CHECK: linalg.generic
5+
// CHECK: arith.constant 0.000000e+00
6+
// CHECK: arith.constant 7.000000e+00
7+
// CHECK: math.roundeven
8+
// CHECK: arith.minf
9+
// CHECK: arith.maxf
10+
// CHECK: arith.fptoui {{.*}} : f32 to i3
11+
// CHECK: builtin.unrealized_conversion_cast
12+
%1 = "tosa.cast"(%arg0) : (tensor<1xf32>) -> tensor<1xui3>
13+
14+
return %1 : tensor<1xui3>
15+
}
16+

0 commit comments

Comments
 (0)