Skip to content

Commit 0312b25

Browse files
committed
[mlir][tosa] Add tosa.table lowering to linalg.generic
Table op lowering to linalg.generic for both i8 (behaves like a single gather) and i16 (a pair of gathers with a quantized interpolation). Differential Revision: https://reviews.llvm.org/D99756
1 parent dfec26b commit 0312b25

File tree

2 files changed

+202
-18
lines changed

2 files changed

+202
-18
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 159 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1407,37 +1407,178 @@ class ArgMaxConverter : public OpRewritePattern<tosa::ArgMaxOp> {
14071407
}
14081408
};
14091409

1410+
// Lowers the TableOp to a series of gathers and numeric operations. This
1411+
// includes interpolation between the high/low values. For the I8 variant, this
1412+
// simplifies to a single gather operation.
1413+
class TableConverter : public OpRewritePattern<tosa::TableOp> {
public:
  using OpRewritePattern<tosa::TableOp>::OpRewritePattern;

  // Rewrites a tosa.table op into a linalg.generic over the input tensor whose
  // body looks the value up in the table.
  //
  // Two element-type combinations are handled:
  //   * i8 input / i8 table / i8 result: a single tensor.extract.
  //   * i16 input / i16 table / i32 result: two tensor.extracts plus a
  //     fixed-point interpolation between them.
  //
  // Returns failure (without modifying any IR) when the input is not
  // statically shaped or the element types match neither combination.
  LogicalResult matchAndRewrite(tosa::TableOp op,
                                PatternRewriter &rewriter) const final {
    auto loc = op.getLoc();
    Value input = op.input();
    Value table = op.table();
    auto inputTy = input.getType().cast<ShapedType>();
    auto tableTy = table.getType().cast<ShapedType>();
    auto resultTy = op.getType().cast<ShapedType>();

    if (!inputTy.hasStaticShape())
      return rewriter.notifyMatchFailure(
          op, "require input type to have static shape");

    auto inputElementTy = inputTy.getElementType();
    auto tableElementTy = tableTy.getElementType();
    auto resultElementTy = resultTy.getElementType();

    const bool isI8Lookup = inputElementTy.isInteger(8) &&
                            tableElementTy.isInteger(8) &&
                            resultElementTy.isInteger(8);
    const bool isI16Interpolation = inputElementTy.isInteger(16) &&
                                    tableElementTy.isInteger(16) &&
                                    resultElementTy.isInteger(32);

    // Reject unsupported element types *before* creating or replacing any
    // operation: a pattern that reports failure must leave the IR untouched.
    if (!isI8Lookup && !isI16Interpolation)
      return rewriter.notifyMatchFailure(
          op, "unable to create body for tosa.table op");

    auto initTensor =
        rewriter
            .create<linalg::InitTensorOp>(loc, ArrayRef<Value>{},
                                          resultTy.getShape(), resultElementTy)
            .result();

    // Input and output are both traversed with the identity map: the lookup
    // is applied independently to every element.
    SmallVector<AffineMap, 2> affineMaps = {
        rewriter.getMultiDimIdentityMap(resultTy.getRank()),
        rewriter.getMultiDimIdentityMap(resultTy.getRank())};

    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTy, ValueRange({input}), ValueRange{initTensor}, affineMaps,
        getNParallelLoopsAttrs(resultTy.getRank()));
    rewriter.replaceOp(op, genericOp.getResult(0));

    // Build the body of the generic op; the guard restores the insertion
    // point when this function returns.
    OpBuilder::InsertionGuard regionGuard(rewriter);
    Block *block =
        rewriter.createBlock(&genericOp.region(), genericOp.region().end(),
                             TypeRange({inputElementTy, resultElementTy}));

    auto inputValue = block->getArgument(0);
    rewriter.setInsertionPointToStart(block);

    if (isI8Lookup) {
      // NOTE(review): the i8 input is cast to index as-is, so negative inputs
      // would presumably yield negative indices. Confirm against the TOSA
      // TABLE specification whether an offset is required here.
      Value index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
                                                 inputValue);
      Value extract =
          rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
      rewriter.create<linalg::YieldOp>(loc, extract);
      return success();
    }

    // Remaining case: i16 input, i16 table, i32 result.
    Value extend =
        rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), inputValue);

    auto offset =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(32768));
    auto seven =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(7));
    auto one = rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(1));
    auto b1111111 =
        rewriter.create<ConstantOp>(loc, rewriter.getI32IntegerAttr(127));

    // Compute the index and fractional part from the input value:
    //   value = value + 32768
    //   index = value >> 7;
    //   fraction = 0b01111111 & value
    auto extendAdd = rewriter.create<AddIOp>(loc, extend, offset);
    Value index = rewriter.create<UnsignedShiftRightOp>(loc, extendAdd, seven);
    Value fraction = rewriter.create<mlir::AndOp>(loc, extendAdd, b1111111);

    // Extract the base and next values from the table.
    //   base = (int32_t) table[index];
    //   next = (int32_t) table[index + 1];
    Value indexPlusOne = rewriter.create<AddIOp>(loc, index, one);

    index = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(), index);
    indexPlusOne = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
                                                indexPlusOne);

    Value base =
        rewriter.create<tensor::ExtractOp>(loc, table, ValueRange{index});
    Value next = rewriter.create<tensor::ExtractOp>(loc, table,
                                                    ValueRange{indexPlusOne});

    base = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), base);
    next = rewriter.create<SignExtendIOp>(loc, rewriter.getI32Type(), next);

    // Use the fractional part to interpolate between the input values:
    //   result = (base << 7) + (next - base) * fraction
    Value baseScaled = rewriter.create<ShiftLeftOp>(loc, base, seven);
    Value diff = rewriter.create<SubIOp>(loc, next, base);
    Value diffScaled = rewriter.create<MulIOp>(loc, diff, fraction);
    Value result = rewriter.create<AddIOp>(loc, baseScaled, diffScaled);

    rewriter.create<linalg::YieldOp>(loc, result);

    return success();
  }
};
1525+
14101526
} // namespace
14111527

14121528
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
14131529
RewritePatternSet *patterns) {
14141530
patterns->add<
1415-
PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
1416-
PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::ReciprocalOp>,
1417-
PointwiseConverter<tosa::NegateOp>, PointwiseConverter<tosa::PowOp>,
1418-
PointwiseConverter<tosa::RsqrtOp>, PointwiseConverter<tosa::LogOp>,
1419-
PointwiseConverter<tosa::ExpOp>, PointwiseConverter<tosa::AbsOp>,
1420-
PointwiseConverter<tosa::TanhOp>, PointwiseConverter<tosa::BitwiseAndOp>,
1531+
// clang-format off
1532+
PointwiseConverter<tosa::AddOp>,
1533+
PointwiseConverter<tosa::SubOp>,
1534+
PointwiseConverter<tosa::MulOp>,
1535+
PointwiseConverter<tosa::NegateOp>,
1536+
PointwiseConverter<tosa::PowOp>,
1537+
PointwiseConverter<tosa::ReciprocalOp>,
1538+
PointwiseConverter<tosa::RsqrtOp>,
1539+
PointwiseConverter<tosa::LogOp>,
1540+
PointwiseConverter<tosa::ExpOp>,
1541+
PointwiseConverter<tosa::AbsOp>,
1542+
PointwiseConverter<tosa::TanhOp>,
1543+
PointwiseConverter<tosa::BitwiseAndOp>,
14211544
PointwiseConverter<tosa::BitwiseOrOp>,
14221545
PointwiseConverter<tosa::BitwiseNotOp>,
14231546
PointwiseConverter<tosa::BitwiseXorOp>,
14241547
PointwiseConverter<tosa::LogicalAndOp>,
14251548
PointwiseConverter<tosa::LogicalNotOp>,
14261549
PointwiseConverter<tosa::LogicalOrOp>,
1427-
PointwiseConverter<tosa::LogicalXorOp>, PointwiseConverter<tosa::CastOp>,
1550+
PointwiseConverter<tosa::LogicalXorOp>,
1551+
PointwiseConverter<tosa::CastOp>,
14281552
PointwiseConverter<tosa::LogicalLeftShiftOp>,
14291553
PointwiseConverter<tosa::LogicalRightShiftOp>,
1430-
PointwiseConverter<tosa::SelectOp>, PointwiseConverter<tosa::GreaterOp>,
1554+
PointwiseConverter<tosa::SelectOp>,
1555+
PointwiseConverter<tosa::GreaterOp>,
14311556
PointwiseConverter<tosa::GreaterEqualOp>,
1432-
PointwiseConverter<tosa::MaximumOp>, PointwiseConverter<tosa::MinimumOp>,
1433-
PointwiseConverter<tosa::CeilOp>, PointwiseConverter<tosa::FloorOp>,
1434-
PointwiseConverter<tosa::ClampOp>, PointwiseConverter<tosa::ReluNOp>,
1435-
PointwiseConverter<tosa::SigmoidOp>, IdentityNConverter<tosa::IdentityOp>,
1436-
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceAllOp>,
1437-
ReduceConverter<tosa::ReduceAnyOp>, ReduceConverter<tosa::ReduceMinOp>,
1438-
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
1439-
ReduceConverter<tosa::ReduceProdOp>, ArgMaxConverter, ConcatConverter,
1440-
PadConverter, ReshapeConverter, RescaleConverter, ReverseConverter,
1441-
TileConverter, TransposeConverter, MatMulConverter,
1557+
PointwiseConverter<tosa::MaximumOp>,
1558+
PointwiseConverter<tosa::MinimumOp>,
1559+
PointwiseConverter<tosa::CeilOp>,
1560+
PointwiseConverter<tosa::FloorOp>,
1561+
PointwiseConverter<tosa::ClampOp>,
1562+
PointwiseConverter<tosa::ReluNOp>,
1563+
PointwiseConverter<tosa::SigmoidOp>,
1564+
IdentityNConverter<tosa::IdentityOp>,
1565+
IdentityNConverter<tosa::IdentityNOp>,
1566+
ReduceConverter<tosa::ReduceAllOp>,
1567+
ReduceConverter<tosa::ReduceAnyOp>,
1568+
ReduceConverter<tosa::ReduceMinOp>,
1569+
ReduceConverter<tosa::ReduceMaxOp>,
1570+
ReduceConverter<tosa::ReduceSumOp>,
1571+
ReduceConverter<tosa::ReduceProdOp>,
1572+
ArgMaxConverter,
1573+
ConcatConverter,
1574+
PadConverter,
1575+
ReshapeConverter,
1576+
RescaleConverter,
1577+
ReverseConverter,
1578+
TableConverter,
1579+
TileConverter,
1580+
TransposeConverter,
1581+
MatMulConverter,
14421582
FullyConnectedConverter>(patterns->getContext());
1583+
// clang-format on
14431584
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,3 +830,46 @@ func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
830830

831831
return
832832
}
833+
834+
// -----
835+
836+
// CHECK-LABEL: @table8
837+
// Verifies the all-i8 tosa.table lowering: the linalg.generic body is a single
// index_cast of the input element followed by a tensor.extract from the table.
func @table8(%arg0: tensor<6xi8>, %arg1: tensor<513xi8>) -> () {
838+
// CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
839+
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi8>) outs(%[[INIT]] : tensor<6xi8>)
840+
// CHECK: ^bb0(%[[ARG_IN:.+]]: i8, %[[ARG_INIT:.+]]: i8)
841+
// CHECK: %[[CAST:.+]] = index_cast %[[ARG_IN]]
842+
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[CAST]]]
843+
// CHECK: linalg.yield %[[EXTRACT]]
844+
%0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi8>, tensor<513xi8>) -> (tensor<6xi8>)
845+
return
846+
}
847+
848+
// CHECK-LABEL: @table16
849+
// Verifies the i16 input / i16 table / i32 result tosa.table lowering: the
// input is sign-extended and offset by 32768, split into a table index (>> 7)
// and a 7-bit fraction (& 127), and the yielded value is
// (base << 7) + (next - base) * fraction.
func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
850+
// CHECK: %[[INIT:.+]] = linalg.init_tensor [6]
851+
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<6xi16>) outs(%[[INIT]] : tensor<6xi32>)
852+
// CHECK: ^bb0(%arg2: i16, %arg3: i32)
853+
// CHECK: %[[EXT_IN:.+]] = sexti %arg2
854+
// CHECK: %[[C32768:.+]] = constant 32768
855+
// CHECK: %[[C7:.+]] = constant 7
856+
// CHECK: %[[C1:.+]] = constant 1
857+
// CHECK: %[[C127:.+]] = constant 127
858+
// CHECK: %[[INADD:.+]] = addi %[[EXT_IN]], %[[C32768]]
859+
// CHECK: %[[IDX:.+]] = shift_right_unsigned %[[INADD]], %[[C7]]
860+
// CHECK: %[[FRACTION:.+]] = and %[[INADD]], %[[C127]]
861+
// CHECK: %[[IDXPLUS1:.+]] = addi %[[IDX]], %[[C1]]
862+
// CHECK: %[[IDX_CAST:.+]] = index_cast %[[IDX]]
863+
// CHECK: %[[IDXPLUS1_CAST:.+]] = index_cast %[[IDXPLUS1]]
864+
// CHECK: %[[BASE:.+]] = tensor.extract %arg1[%[[IDX_CAST]]]
865+
// CHECK: %[[NEXT:.+]] = tensor.extract %arg1[%[[IDXPLUS1_CAST]]]
866+
// CHECK: %[[BASE_EXT:.+]] = sexti %[[BASE]]
867+
// CHECK: %[[NEXT_EXT:.+]] = sexti %[[NEXT]]
868+
// CHECK: %[[BASE_MUL:.+]] = shift_left %[[BASE_EXT]], %[[C7]]
869+
// CHECK: %[[DIFF:.+]] = subi %[[NEXT_EXT]], %[[BASE_EXT]]
870+
// CHECK: %[[DIFF_MUL:.+]] = muli %[[DIFF]], %[[FRACTION]]
871+
// CHECK: %[[RESULT:.+]] = addi %[[BASE_MUL]], %[[DIFF_MUL]]
872+
// CHECK: linalg.yield %[[RESULT]]
873+
%0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>)
874+
return
875+
}

0 commit comments

Comments
 (0)